zeroshade commented on code in PR #40496: URL: https://github.com/apache/arrow/pull/40496#discussion_r1588258204
########## go/arrow/util/protobuf_reflect.go: ########## @@ -0,0 +1,762 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "fmt" + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/memory" + "github.com/huandu/xstrings" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/known/anypb" + "reflect" Review Comment: Can we fix up the imports with goimports (gofmt alone only sorts imports within an existing group)? Standard-library imports should come first, followed by a blank line, and then the third-party imports. ########## go/arrow/util/protobuf_reflect.go: ########## @@ -0,0 +1,762 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "fmt" + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/memory" + "github.com/huandu/xstrings" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/known/anypb" + "reflect" +) + +type OneOfHandler int + +const ( + // Null means do not wrap oneOfs in a union, they are treated as separate fields + Null OneOfHandler = iota + // DenseUnion maps the protobuf OneOf to an arrow.DENSE_UNION + DenseUnion +) + +type schemaOptions struct { + exclusionPolicy func(pfr *protobufFieldReflection) bool + fieldNameFormatter func(str string) string + oneOfHandler OneOfHandler +} + +// ProtobufStructReflection represents the metadata and values of a protobuf message +type ProtobufStructReflection struct { + descriptor protoreflect.MessageDescriptor + message protoreflect.Message + rValue reflect.Value + schemaOptions +} + +type protobufReflection interface { + name() string + arrowType() arrow.Type + protoreflectValue() protoreflect.Value + reflectValue() reflect.Value + getDescriptor() protoreflect.FieldDescriptor + isNull() bool + asDictionary() protobufDictReflection + asList() protobufListReflection + asMap() protobufMapReflection + asStruct() ProtobufStructReflection + asUnion() protobufUnionReflection +} + +type SuperField struct { + parent *ProtobufStructReflection + protobufReflection + arrow.Field +} + +type SuperMessage struct { + superFields []SuperField +} + +func (sm SuperMessage) 
Schema() *arrow.Schema { + var fields []arrow.Field + for _, sf := range sm.superFields { + fields = append(fields, sf.Field) + } + return arrow.NewSchema(fields, nil) +} + +func (sm SuperMessage) Record(mem memory.Allocator) arrow.Record { + if mem == nil { + mem = memory.NewGoAllocator() + } + + schema := sm.Schema() + + recordBuilder := array.NewRecordBuilder(mem, schema) + + var fieldNames []string + for i, sf := range sm.superFields { + sf.AppendValueOrNull(recordBuilder.Field(i), mem) + fieldNames = append(fieldNames, sf.protobufReflection.name()) + } + + var arrays []arrow.Array + for _, bldr := range recordBuilder.Fields() { + a := bldr.NewArray() + arrays = append(arrays, a) + } + + structArray, _ := array.NewStructArray(arrays, fieldNames) + + return array.RecordFromStructArray(structArray, schema) +} + +func NewSuperMessage(msg proto.Message, options ...option) *SuperMessage { + v := reflect.ValueOf(msg) + for v.Kind() == reflect.Ptr { + v = v.Elem() + } + includeAll := func(pfr *protobufFieldReflection) bool { + return false + } + noFormatting := func(str string) string { + return str + } + psr := &ProtobufStructReflection{ + descriptor: msg.ProtoReflect().Descriptor(), + message: msg.ProtoReflect(), + rValue: v, + schemaOptions: schemaOptions{ + exclusionPolicy: includeAll, + fieldNameFormatter: noFormatting, + oneOfHandler: Null, + }, + } + + for _, opt := range options { + opt(psr) + } + + var fields []SuperField + + for pfr := range psr.generateFields() { + fields = append(fields, SuperField{ + parent: psr, + protobufReflection: pfr, + Field: pfr.arrowField(), + }) + } + + return &SuperMessage{superFields: fields} +} + +type option func(*ProtobufStructReflection) + +// WithExclusionPolicy is an option for a ProtobufStructReflection +// WithExclusionPolicy acts as a deny filter on the fields of a protobuf message +// i.e. prevents them from being included in the schema. +// A use case for this is to exclude fields containing PII. 
+func WithExclusionPolicy(ex func(pfr *protobufFieldReflection) bool) option { + return func(psr *ProtobufStructReflection) { + psr.exclusionPolicy = ex + } +} + +// WithFieldNameFormatter is an option for a ProtobufStructReflection +// WithFieldNameFormatter enables customisation of the field names in the arrow schema +// By default, the field names are taken from the protobuf message (.proto file) +func WithFieldNameFormatter(formatter func(str string) string) option { + return func(psr *ProtobufStructReflection) { + psr.fieldNameFormatter = formatter + } +} + +// WithOneOfHandler is an option for a ProtobufStructReflection +// WithOneOfHandler enables customisation of the protobuf oneOf type in the arrow schema +// By default, the oneOfs are mapped to separate columns +func WithOneOfHandler(oneOfHandler OneOfHandler) option { + return func(psr *ProtobufStructReflection) { + psr.oneOfHandler = oneOfHandler + } +} + +func (psr ProtobufStructReflection) unmarshallAny() ProtobufStructReflection { + if psr.descriptor.FullName() == "google.protobuf.Any" && psr.rValue.IsValid() { + for psr.rValue.Type().Kind() == reflect.Ptr { + psr.rValue = reflect.Indirect(psr.rValue) + } + fieldValueAsAny, _ := psr.rValue.Interface().(anypb.Any) + msg, _ := fieldValueAsAny.UnmarshalNew() + + v := reflect.ValueOf(msg) + for v.Kind() == reflect.Ptr { + v = reflect.Indirect(v) + } + + return ProtobufStructReflection{ + descriptor: msg.ProtoReflect().Descriptor(), + message: msg.ProtoReflect(), + rValue: v, + schemaOptions: psr.schemaOptions, + } + } else { + return psr + } +} + +func (psr ProtobufStructReflection) getArrowFields() []arrow.Field { + var fields []arrow.Field + + for pfr := range psr.generateStructFields() { + fields = append(fields, arrow.Field{ + Name: pfr.name(), + Type: pfr.getDataType(), + Nullable: true, + }) + } + + return fields +} + +type protobufListReflection struct { + protobufFieldReflection +} + +func (pfr *protobufFieldReflection) asList() 
protobufListReflection { + return protobufListReflection{*pfr} +} + +func (plr protobufListReflection) getDataType() arrow.DataType { + for li := range plr.generateListItems() { + return arrow.ListOf(li.getDataType()) + } + pfr := protobufFieldReflection{ + descriptor: plr.descriptor, + schemaOptions: plr.schemaOptions, + } + return arrow.ListOf(pfr.getDataType()) +} + +type protobufUnionReflection struct { + protobufFieldReflection +} + +func (pfr *protobufFieldReflection) asUnion() protobufUnionReflection { + return protobufUnionReflection{*pfr} +} + +func (pur protobufUnionReflection) isThisOne() bool { + for pur.rValue.Kind() == reflect.Ptr || pur.rValue.Kind() == reflect.Interface { + pur.rValue = pur.rValue.Elem() + } + return pur.rValue.Field(0).String() == pur.prValue.String() +} + +func (pur protobufUnionReflection) whichOne() arrow.UnionTypeCode { + fds := pur.descriptor.ContainingOneof().Fields() + for i := 0; i < fds.Len(); i++ { + pfr := pur.parent.getFieldByName(string(fds.Get(i).Name())) + if pfr.asUnion().isThisOne() { + return pur.getUnionTypeCode(int32(pfr.descriptor.Number())) + } + } + // i.e. all null + return -1 +} + +func (pur protobufUnionReflection) getField() *protobufFieldReflection { + fds := pur.descriptor.ContainingOneof().Fields() + for i := 0; i < fds.Len(); i++ { + pfr := pur.parent.getFieldByName(string(fds.Get(i).Name())) + if pfr.asUnion().isThisOne() { + return pfr + } + } + // i.e. 
all null + return nil +} + +func (pur protobufUnionReflection) getUnionTypeCode(n int32) arrow.UnionTypeCode { + //We use the index of the field number as there is a limit on the arrow.UnionTypeCode (127) + //which a protobuf Number could realistically exceed + fds := pur.descriptor.ContainingOneof().Fields() + for i := 0; i < fds.Len(); i++ { + if n == int32(fds.Get(i).Number()) { + return int8(i) + } + } + return -1 +} + +func (pur protobufUnionReflection) generateUnionFields() chan *protobufFieldReflection { + out := make(chan *protobufFieldReflection) + go func() { + defer close(out) + fds := pur.descriptor.ContainingOneof().Fields() + for i := 0; i < fds.Len(); i++ { + pfr := pur.parent.getFieldByName(string(fds.Get(i).Name())) + // Do not get stuck in a recursion loop + pfr.oneOfHandler = Null + if pfr.exclusionPolicy(pfr) { + continue + } + out <- pfr + } + }() + + return out +} + +func (pur protobufUnionReflection) getArrowFields() []arrow.Field { + var fields []arrow.Field + + for pfr := range pur.generateUnionFields() { + fields = append(fields, pfr.arrowField()) + } + + return fields +} + +func (pur protobufUnionReflection) getDataType() arrow.DataType { + fds := pur.getArrowFields() + typeCodes := make([]arrow.UnionTypeCode, len(fds)) + for i := 0; i < len(fds); i++ { + typeCodes[i] = arrow.UnionTypeCode(i) + } + return arrow.DenseUnionOf(fds, typeCodes) +} + +type protobufDictReflection struct { + protobufFieldReflection +} + +func (pfr *protobufFieldReflection) asDictionary() protobufDictReflection { + return protobufDictReflection{*pfr} +} + +func (pdr protobufDictReflection) getDataType() arrow.DataType { + return &arrow.DictionaryType{ + IndexType: arrow.PrimitiveTypes.Int32, + ValueType: arrow.BinaryTypes.String, + Ordered: false, + } +} + +type protobufMapReflection struct { + protobufFieldReflection +} + +func (pfr *protobufFieldReflection) asMap() protobufMapReflection { + return protobufMapReflection{*pfr} +} + +func (pmr 
protobufMapReflection) getDataType() arrow.DataType { + for kvp := range pmr.generateKeyValuePairs() { + return kvp.getDataType() + } + return protobufMapKeyValuePairReflection{ + k: protobufFieldReflection{ + parent: pmr.parent, + descriptor: pmr.descriptor.MapKey(), + schemaOptions: pmr.schemaOptions, + }, + v: protobufFieldReflection{ + parent: pmr.parent, + descriptor: pmr.descriptor.MapValue(), + schemaOptions: pmr.schemaOptions, + }, + }.getDataType() +} + +type protobufMapKeyValuePairReflection struct { + k protobufFieldReflection + v protobufFieldReflection +} + +func (pmr protobufMapKeyValuePairReflection) getDataType() arrow.DataType { + return arrow.MapOf(pmr.k.getDataType(), pmr.v.getDataType()) +} + +func (pmr protobufMapReflection) generateKeyValuePairs() chan protobufMapKeyValuePairReflection { + out := make(chan protobufMapKeyValuePairReflection) + + go func() { + defer close(out) + for _, k := range pmr.rValue.MapKeys() { + kvp := protobufMapKeyValuePairReflection{ + k: protobufFieldReflection{ + parent: pmr.parent, + descriptor: pmr.descriptor.MapKey(), + prValue: getMapKey(k), + rValue: k, + schemaOptions: pmr.schemaOptions, + }, + v: protobufFieldReflection{ + parent: pmr.parent, + descriptor: pmr.descriptor.MapValue(), + prValue: pmr.prValue.Map().Get(protoreflect.MapKey(getMapKey(k))), + rValue: pmr.rValue.MapIndex(k), + schemaOptions: pmr.schemaOptions, + }, + } + out <- kvp + } + }() + + return out +} + +func getMapKey(v reflect.Value) protoreflect.Value { + switch v.Kind() { + case reflect.String: + return protoreflect.ValueOf(v.String()) + case reflect.Int32, reflect.Int64: + return protoreflect.ValueOf(v.Int()) + case reflect.Bool: + return protoreflect.ValueOf(v.Bool()) + case reflect.Uint32, reflect.Uint64: + return protoreflect.ValueOf(v.Uint()) + default: + panic("Unmapped protoreflect map key type") + } +} + +func (psr ProtobufStructReflection) generateStructFields() chan *protobufFieldReflection { + out := make(chan 
*protobufFieldReflection) + + go func() { + defer close(out) + fds := psr.descriptor.Fields() + for i := 0; i < fds.Len(); i++ { + pfr := psr.getFieldByName(string(fds.Get(i).Name())) + if psr.exclusionPolicy(pfr) { + continue + } + if pfr.arrowType() == arrow.DENSE_UNION { + if pfr.descriptor.Number() != pfr.descriptor.ContainingOneof().Fields().Get(0).Number() { + continue + } + } + out <- pfr + } + }() + + return out +} + +func (psr ProtobufStructReflection) generateFields() chan *protobufFieldReflection { + out := make(chan *protobufFieldReflection) + + go func() { + defer close(out) + fds := psr.descriptor.Fields() + for i := 0; i < fds.Len(); i++ { + pfr := psr.getFieldByName(string(fds.Get(i).Name())) + if psr.exclusionPolicy(pfr) { + continue + } + if pfr.arrowType() == arrow.DENSE_UNION { + if pfr.descriptor.Number() != pfr.descriptor.ContainingOneof().Fields().Get(0).Number() { + continue + } + } + out <- pfr + } + }() + + return out +} + +func (pfr *protobufFieldReflection) asStruct() ProtobufStructReflection { + psr := ProtobufStructReflection{ + descriptor: pfr.descriptor.Message(), + rValue: pfr.rValue, + schemaOptions: pfr.schemaOptions, + } + if pfr.prValue.IsValid() { + psr.message = pfr.prValue.Message() + } + psr = psr.unmarshallAny() + return psr +} + +func (psr ProtobufStructReflection) getDataType() arrow.DataType { + return arrow.StructOf(psr.getArrowFields()...) 
+} + +func (psr ProtobufStructReflection) getFieldByName(n string) *protobufFieldReflection { + fd := psr.descriptor.Fields().ByTextName(xstrings.ToSnakeCase(n)) + fv := psr.rValue + if fv.IsValid() { + if !fv.IsZero() { + for fv.Kind() == reflect.Ptr || fv.Kind() == reflect.Interface { + fv = fv.Elem() + } + if fd.ContainingOneof() != nil { + n = string(fd.ContainingOneof().Name()) + } + fv = fv.FieldByName(xstrings.ToCamelCase(n)) + for fv.Kind() == reflect.Ptr { + fv = fv.Elem() + } + } + } + pfr := protobufFieldReflection{ + parent: &psr, + descriptor: fd, + rValue: fv, + schemaOptions: psr.schemaOptions, + } + if psr.message != nil { + pfr.prValue = psr.message.Get(fd) + } + return &pfr +} + +type protobufFieldReflection struct { + parent *ProtobufStructReflection + descriptor protoreflect.FieldDescriptor + prValue protoreflect.Value + rValue reflect.Value + schemaOptions +} + +func (pfr *protobufFieldReflection) isNull() bool { + for pfr.rValue.Kind() == reflect.Ptr { + if pfr.rValue.IsNil() { + return true + } + pfr.rValue = pfr.rValue.Elem() + } + + if !pfr.rValue.IsValid() && !pfr.prValue.IsValid() { + return true + } + return false +} + +func (pfr *protobufFieldReflection) arrowField() arrow.Field { + return arrow.Field{ + Name: pfr.name(), + Type: pfr.getDataType(), + Nullable: true, + } +} + +func (pfr *protobufFieldReflection) protoreflectValue() protoreflect.Value { + return pfr.prValue +} + +func (pfr *protobufFieldReflection) reflectValue() reflect.Value { + return pfr.rValue +} + +func (pfr *protobufFieldReflection) getDescriptor() protoreflect.FieldDescriptor { + return pfr.descriptor +} + +func (pfr *protobufFieldReflection) name() string { + if pfr.isOneOf() && pfr.schemaOptions.oneOfHandler != Null { + return pfr.fieldNameFormatter(string(pfr.descriptor.ContainingOneof().Name())) + } + return pfr.fieldNameFormatter(string(pfr.descriptor.Name())) +} + +func (pfr *protobufFieldReflection) arrowType() arrow.Type { + if pfr.isOneOf() && 
pfr.schemaOptions.oneOfHandler == DenseUnion { + return arrow.DENSE_UNION + } + if pfr.isEnum() { + return arrow.DICTIONARY + } + if pfr.isStruct() { + return arrow.STRUCT + } + if pfr.isMap() { + return arrow.MAP + } + if pfr.isList() { + return arrow.LIST + } + return arrow.NULL +} + +func (pfr *protobufFieldReflection) isOneOf() bool { + return pfr.descriptor.ContainingOneof() != nil +} + +func (pfr *protobufFieldReflection) isEnum() bool { + return pfr.descriptor.Kind() == protoreflect.EnumKind +} + +func (pfr *protobufFieldReflection) isStruct() bool { + return pfr.descriptor.Kind() == protoreflect.MessageKind && !pfr.descriptor.IsMap() && pfr.rValue.Kind() != reflect.Slice +} + +func (pfr *protobufFieldReflection) isMap() bool { + return pfr.descriptor.Kind() == protoreflect.MessageKind && pfr.descriptor.IsMap() +} + +func (pfr *protobufFieldReflection) isList() bool { + return pfr.descriptor.IsList() && pfr.rValue.Kind() == reflect.Slice +} + +func (plr protobufListReflection) generateListItems() chan protobufFieldReflection { + out := make(chan protobufFieldReflection) + + go func() { + defer close(out) + for i := 0; i < plr.prValue.List().Len(); i++ { + out <- protobufFieldReflection{ + descriptor: plr.descriptor, + prValue: plr.prValue.List().Get(i), + rValue: plr.rValue.Index(i), + schemaOptions: plr.schemaOptions, + } + } + }() + + return out +} + +func (pfr *protobufFieldReflection) getDataType() arrow.DataType { + var dt arrow.DataType + + typeMap := map[protoreflect.Kind]arrow.DataType{ + //Numeric + protoreflect.Int32Kind: arrow.PrimitiveTypes.Int32, + protoreflect.Int64Kind: arrow.PrimitiveTypes.Int64, + protoreflect.Sint32Kind: arrow.PrimitiveTypes.Int32, + protoreflect.Sint64Kind: arrow.PrimitiveTypes.Int64, + protoreflect.Uint32Kind: arrow.PrimitiveTypes.Uint32, + protoreflect.Uint64Kind: arrow.PrimitiveTypes.Uint64, + protoreflect.Fixed32Kind: arrow.PrimitiveTypes.Uint32, + protoreflect.Fixed64Kind: arrow.PrimitiveTypes.Uint64, + 
+ protoreflect.Sfixed32Kind: arrow.PrimitiveTypes.Int32, + protoreflect.Sfixed64Kind: arrow.PrimitiveTypes.Int64, + protoreflect.FloatKind: arrow.PrimitiveTypes.Float32, + protoreflect.DoubleKind: arrow.PrimitiveTypes.Float64, + //Binary + protoreflect.StringKind: arrow.BinaryTypes.String, + protoreflect.BytesKind: arrow.BinaryTypes.Binary, + //Fixed Width + protoreflect.BoolKind: arrow.FixedWidthTypes.Boolean, + // Special + protoreflect.EnumKind: nil, + protoreflect.MessageKind: nil, + } + dt = typeMap[pfr.descriptor.Kind()] + + switch pfr.arrowType() { + case arrow.DENSE_UNION: + dt = pfr.asUnion().getDataType() + case arrow.DICTIONARY: + dt = pfr.asDictionary().getDataType() + case arrow.LIST: + dt = pfr.asList().getDataType() + case arrow.MAP: + dt = pfr.asMap().getDataType() + case arrow.STRUCT: + dt = pfr.asStruct().getDataType() + } + + return dt +} + +func (sf SuperField) AppendValueOrNull(b array.Builder, mem memory.Allocator) { Review Comment: Please add a docstring here. Also, can we move this closer to the other `SuperField` methods? ########## go/arrow/util/protobuf_reflect_test.go: ########## @@ -0,0 +1,247 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +package util + +import ( + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/memory" + "github.com/apache/arrow/go/v17/arrow/util/util_message" + "github.com/huandu/xstrings" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/anypb" + "strings" + "testing" Review Comment: Please fix the import formatting here too: standard-library imports (`strings`, `testing`) first, then a blank line, then the third-party imports. ########## go/arrow/util/protobuf_reflect.go: ########## @@ -0,0 +1,762 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +package util + +import ( + "fmt" + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/memory" + "github.com/huandu/xstrings" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/known/anypb" + "reflect" +) + +type OneOfHandler int + +const ( + // Null means do not wrap oneOfs in a union, they are treated as separate fields + Null OneOfHandler = iota + // DenseUnion maps the protobuf OneOf to an arrow.DENSE_UNION + DenseUnion +) + +type schemaOptions struct { + exclusionPolicy func(pfr *protobufFieldReflection) bool + fieldNameFormatter func(str string) string + oneOfHandler OneOfHandler +} + +// ProtobufStructReflection represents the metadata and values of a protobuf message +type ProtobufStructReflection struct { + descriptor protoreflect.MessageDescriptor + message protoreflect.Message + rValue reflect.Value + schemaOptions +} + +type protobufReflection interface { + name() string + arrowType() arrow.Type + protoreflectValue() protoreflect.Value + reflectValue() reflect.Value + getDescriptor() protoreflect.FieldDescriptor + isNull() bool + asDictionary() protobufDictReflection + asList() protobufListReflection + asMap() protobufMapReflection + asStruct() ProtobufStructReflection + asUnion() protobufUnionReflection +} + +type SuperField struct { + parent *ProtobufStructReflection + protobufReflection + arrow.Field +} + +type SuperMessage struct { + superFields []SuperField +} + +func (sm SuperMessage) Schema() *arrow.Schema { Review Comment: doc string ########## go/arrow/util/protobuf_reflect.go: ########## @@ -0,0 +1,762 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "fmt" + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/memory" + "github.com/huandu/xstrings" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/known/anypb" + "reflect" +) + +type OneOfHandler int + +const ( + // Null means do not wrap oneOfs in a union, they are treated as separate fields + Null OneOfHandler = iota + // DenseUnion maps the protobuf OneOf to an arrow.DENSE_UNION + DenseUnion +) + +type schemaOptions struct { + exclusionPolicy func(pfr *protobufFieldReflection) bool + fieldNameFormatter func(str string) string + oneOfHandler OneOfHandler +} + +// ProtobufStructReflection represents the metadata and values of a protobuf message +type ProtobufStructReflection struct { + descriptor protoreflect.MessageDescriptor + message protoreflect.Message + rValue reflect.Value + schemaOptions +} + +type protobufReflection interface { + name() string + arrowType() arrow.Type + protoreflectValue() protoreflect.Value + reflectValue() reflect.Value + getDescriptor() protoreflect.FieldDescriptor + isNull() bool + asDictionary() protobufDictReflection + asList() protobufListReflection + asMap() protobufMapReflection + asStruct() ProtobufStructReflection + asUnion() protobufUnionReflection +} + +type SuperField 
struct { + parent *ProtobufStructReflection + protobufReflection + arrow.Field +} + +type SuperMessage struct { + superFields []SuperField +} Review Comment: Could these have doc strings? Also, can we come up with a better name than `SuperMessage`? :smile: ########## go/arrow/util/protobuf_reflect.go: ########## @@ -0,0 +1,762 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "fmt" + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/memory" + "github.com/huandu/xstrings" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/known/anypb" + "reflect" +) + +type OneOfHandler int Review Comment: Please add a docstring for this. ########## go/arrow/util/protobuf_reflect.go: ########## @@ -0,0 +1,762 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership.
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "fmt" + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/memory" + "github.com/huandu/xstrings" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/known/anypb" + "reflect" +) + +type OneOfHandler int + +const ( + // Null means do not wrap oneOfs in a union, they are treated as separate fields + Null OneOfHandler = iota + // DenseUnion maps the protobuf OneOf to an arrow.DENSE_UNION + DenseUnion +) + +type schemaOptions struct { + exclusionPolicy func(pfr *protobufFieldReflection) bool + fieldNameFormatter func(str string) string + oneOfHandler OneOfHandler +} + +// ProtobufStructReflection represents the metadata and values of a protobuf message +type ProtobufStructReflection struct { + descriptor protoreflect.MessageDescriptor + message protoreflect.Message + rValue reflect.Value + schemaOptions +} + +type protobufReflection interface { + name() string + arrowType() arrow.Type + protoreflectValue() protoreflect.Value + reflectValue() reflect.Value + getDescriptor() protoreflect.FieldDescriptor + isNull() bool + asDictionary() protobufDictReflection + asList() protobufListReflection + asMap() protobufMapReflection + asStruct() ProtobufStructReflection + asUnion() protobufUnionReflection +} + +type SuperField 
struct { + parent *ProtobufStructReflection + protobufReflection + arrow.Field +} + +type SuperMessage struct { + superFields []SuperField +} + +func (sm SuperMessage) Schema() *arrow.Schema { + var fields []arrow.Field + for _, sf := range sm.superFields { + fields = append(fields, sf.Field) + } + return arrow.NewSchema(fields, nil) +} + +func (sm SuperMessage) Record(mem memory.Allocator) arrow.Record { Review Comment: doc string ########## go/arrow/util/protobuf_reflect.go: ########## @@ -0,0 +1,762 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package util + +import ( + "fmt" + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/memory" + "github.com/huandu/xstrings" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/known/anypb" + "reflect" +) + +type OneOfHandler int + +const ( + // Null means do not wrap oneOfs in a union, they are treated as separate fields + Null OneOfHandler = iota + // DenseUnion maps the protobuf OneOf to an arrow.DENSE_UNION + DenseUnion +) + +type schemaOptions struct { + exclusionPolicy func(pfr *protobufFieldReflection) bool + fieldNameFormatter func(str string) string + oneOfHandler OneOfHandler +} + +// ProtobufStructReflection represents the metadata and values of a protobuf message +type ProtobufStructReflection struct { + descriptor protoreflect.MessageDescriptor + message protoreflect.Message + rValue reflect.Value + schemaOptions +} + +type protobufReflection interface { + name() string + arrowType() arrow.Type + protoreflectValue() protoreflect.Value + reflectValue() reflect.Value + getDescriptor() protoreflect.FieldDescriptor + isNull() bool + asDictionary() protobufDictReflection + asList() protobufListReflection + asMap() protobufMapReflection + asStruct() ProtobufStructReflection + asUnion() protobufUnionReflection +} + +type SuperField struct { + parent *ProtobufStructReflection + protobufReflection + arrow.Field +} + +type SuperMessage struct { + superFields []SuperField +} + +func (sm SuperMessage) Schema() *arrow.Schema { + var fields []arrow.Field + for _, sf := range sm.superFields { + fields = append(fields, sf.Field) + } + return arrow.NewSchema(fields, nil) +} + +func (sm SuperMessage) Record(mem memory.Allocator) arrow.Record { + if mem == nil { + mem = memory.NewGoAllocator() + } + + schema := sm.Schema() + + recordBuilder := array.NewRecordBuilder(mem, schema) + + var fieldNames []string + for i, sf := 
range sm.superFields { + sf.AppendValueOrNull(recordBuilder.Field(i), mem) + fieldNames = append(fieldNames, sf.protobufReflection.name()) + } + + var arrays []arrow.Array + for _, bldr := range recordBuilder.Fields() { + a := bldr.NewArray() + arrays = append(arrays, a) + } + + structArray, _ := array.NewStructArray(arrays, fieldNames) + + return array.RecordFromStructArray(structArray, schema) +} + +func NewSuperMessage(msg proto.Message, options ...option) *SuperMessage { Review Comment: doc string ########## go/arrow/util/protobuf_reflect_test.go: ########## @@ -0,0 +1,247 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package util + +import ( + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/memory" + "github.com/apache/arrow/go/v17/arrow/util/util_message" + "github.com/huandu/xstrings" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/anypb" + "strings" + "testing" +) + +func SetupTest() util_message.AllTheTypes { + msg := util_message.ExampleMessage{ + Field1: "Example", + } + + anyMsg, _ := anypb.New(&msg) + + return util_message.AllTheTypes{ + String_: "Hello", + Int32: 10, + Int64: 100, + Sint32: -10, + Sin64: -100, + Uint32: 10, + Uint64: 100, + Fixed32: 10, + Fixed64: 1000, + Sfixed32: 10, + Bool: false, + Bytes: []byte("Hello, world!"), + Double: 1.1, + Enum: util_message.AllTheTypes_OPTION_0, + Message: &msg, + Oneof: &util_message.AllTheTypes_Oneofstring{Oneofstring: "World"}, + Any: anyMsg, + //Breaks the test as the `RecordFromJson` randomises the order + //SimpleMap: map[int32]string{99: "Hello", 100: "World", 98: "How", 101: "Are", 1: "You"}, Review Comment: RecordFromJSON doesn't randomize the order; Go maps have a non-deterministic iteration order. ########## go/arrow/util/protobuf_reflect.go: ########## @@ -0,0 +1,762 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "fmt" + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/memory" + "github.com/huandu/xstrings" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/known/anypb" + "reflect" +) + +type OneOfHandler int + +const ( + // Null means do not wrap oneOfs in a union, they are treated as separate fields + Null OneOfHandler = iota + // DenseUnion maps the protobuf OneOf to an arrow.DENSE_UNION + DenseUnion +) + +type schemaOptions struct { + exclusionPolicy func(pfr *protobufFieldReflection) bool + fieldNameFormatter func(str string) string + oneOfHandler OneOfHandler +} + +// ProtobufStructReflection represents the metadata and values of a protobuf message +type ProtobufStructReflection struct { + descriptor protoreflect.MessageDescriptor + message protoreflect.Message + rValue reflect.Value + schemaOptions +} + +type protobufReflection interface { + name() string + arrowType() arrow.Type + protoreflectValue() protoreflect.Value + reflectValue() reflect.Value + getDescriptor() protoreflect.FieldDescriptor + isNull() bool + asDictionary() protobufDictReflection + asList() protobufListReflection + asMap() protobufMapReflection + asStruct() ProtobufStructReflection + asUnion() protobufUnionReflection +} + +type SuperField struct { + parent *ProtobufStructReflection + protobufReflection + arrow.Field +} + +type SuperMessage struct { + superFields []SuperField +} + +func (sm SuperMessage) Schema() *arrow.Schema { + var fields []arrow.Field + for _, sf := range sm.superFields { + fields = append(fields, sf.Field) + } + return arrow.NewSchema(fields, nil) +} + +func (sm SuperMessage) Record(mem memory.Allocator) arrow.Record { + if mem == nil { + mem = memory.NewGoAllocator() + } + + schema := 
sm.Schema() + + recordBuilder := array.NewRecordBuilder(mem, schema) + + var fieldNames []string + for i, sf := range sm.superFields { + sf.AppendValueOrNull(recordBuilder.Field(i), mem) + fieldNames = append(fieldNames, sf.protobufReflection.name()) + } + + var arrays []arrow.Array + for _, bldr := range recordBuilder.Fields() { + a := bldr.NewArray() + arrays = append(arrays, a) + } + + structArray, _ := array.NewStructArray(arrays, fieldNames) + + return array.RecordFromStructArray(structArray, schema) +} + +func NewSuperMessage(msg proto.Message, options ...option) *SuperMessage { + v := reflect.ValueOf(msg) + for v.Kind() == reflect.Ptr { + v = v.Elem() + } + includeAll := func(pfr *protobufFieldReflection) bool { + return false + } + noFormatting := func(str string) string { + return str + } + psr := &ProtobufStructReflection{ + descriptor: msg.ProtoReflect().Descriptor(), + message: msg.ProtoReflect(), + rValue: v, + schemaOptions: schemaOptions{ + exclusionPolicy: includeAll, + fieldNameFormatter: noFormatting, + oneOfHandler: Null, + }, + } + + for _, opt := range options { + opt(psr) + } + + var fields []SuperField + + for pfr := range psr.generateFields() { + fields = append(fields, SuperField{ + parent: psr, + protobufReflection: pfr, + Field: pfr.arrowField(), + }) + } + + return &SuperMessage{superFields: fields} +} + +type option func(*ProtobufStructReflection) + +// WithExclusionPolicy is an option for a ProtobufStructReflection +// WithExclusionPolicy acts as a deny filter on the fields of a protobuf message +// i.e. prevents them from being included in the schema. +// A use case for this is to exclude fields containing PII. 
+func WithExclusionPolicy(ex func(pfr *protobufFieldReflection) bool) option { + return func(psr *ProtobufStructReflection) { + psr.exclusionPolicy = ex + } +} + +// WithFieldNameFormatter is an option for a ProtobufStructReflection +// WithFieldNameFormatter enables customisation of the field names in the arrow schema +// By default, the field names are taken from the protobuf message (.proto file) +func WithFieldNameFormatter(formatter func(str string) string) option { + return func(psr *ProtobufStructReflection) { + psr.fieldNameFormatter = formatter + } +} + +// WithOneOfHandler is an option for a ProtobufStructReflection +// WithOneOfHandler enables customisation of the protobuf oneOf type in the arrow schema +// By default, the oneOfs are mapped to separate columns +func WithOneOfHandler(oneOfHandler OneOfHandler) option { + return func(psr *ProtobufStructReflection) { + psr.oneOfHandler = oneOfHandler + } +} + +func (psr ProtobufStructReflection) unmarshallAny() ProtobufStructReflection { + if psr.descriptor.FullName() == "google.protobuf.Any" && psr.rValue.IsValid() { + for psr.rValue.Type().Kind() == reflect.Ptr { + psr.rValue = reflect.Indirect(psr.rValue) + } + fieldValueAsAny, _ := psr.rValue.Interface().(anypb.Any) + msg, _ := fieldValueAsAny.UnmarshalNew() + + v := reflect.ValueOf(msg) + for v.Kind() == reflect.Ptr { + v = reflect.Indirect(v) + } + + return ProtobufStructReflection{ + descriptor: msg.ProtoReflect().Descriptor(), + message: msg.ProtoReflect(), + rValue: v, + schemaOptions: psr.schemaOptions, + } + } else { + return psr + } +} + +func (psr ProtobufStructReflection) getArrowFields() []arrow.Field { + var fields []arrow.Field + + for pfr := range psr.generateStructFields() { + fields = append(fields, arrow.Field{ + Name: pfr.name(), + Type: pfr.getDataType(), + Nullable: true, + }) + } + + return fields +} + +type protobufListReflection struct { + protobufFieldReflection +} + +func (pfr *protobufFieldReflection) asList() 
protobufListReflection { + return protobufListReflection{*pfr} +} + +func (plr protobufListReflection) getDataType() arrow.DataType { + for li := range plr.generateListItems() { + return arrow.ListOf(li.getDataType()) + } + pfr := protobufFieldReflection{ + descriptor: plr.descriptor, + schemaOptions: plr.schemaOptions, + } + return arrow.ListOf(pfr.getDataType()) +} + +type protobufUnionReflection struct { + protobufFieldReflection +} + +func (pfr *protobufFieldReflection) asUnion() protobufUnionReflection { + return protobufUnionReflection{*pfr} +} + +func (pur protobufUnionReflection) isThisOne() bool { + for pur.rValue.Kind() == reflect.Ptr || pur.rValue.Kind() == reflect.Interface { + pur.rValue = pur.rValue.Elem() + } + return pur.rValue.Field(0).String() == pur.prValue.String() +} + +func (pur protobufUnionReflection) whichOne() arrow.UnionTypeCode { + fds := pur.descriptor.ContainingOneof().Fields() + for i := 0; i < fds.Len(); i++ { + pfr := pur.parent.getFieldByName(string(fds.Get(i).Name())) + if pfr.asUnion().isThisOne() { + return pur.getUnionTypeCode(int32(pfr.descriptor.Number())) + } + } + // i.e. all null + return -1 +} + +func (pur protobufUnionReflection) getField() *protobufFieldReflection { + fds := pur.descriptor.ContainingOneof().Fields() + for i := 0; i < fds.Len(); i++ { + pfr := pur.parent.getFieldByName(string(fds.Get(i).Name())) + if pfr.asUnion().isThisOne() { + return pfr + } + } + // i.e. 
all null + return nil +} + +func (pur protobufUnionReflection) getUnionTypeCode(n int32) arrow.UnionTypeCode { + //We use the index of the field number as there is a limit on the arrow.UnionTypeCode (127) + //which a protobuf Number could realistically exceed + fds := pur.descriptor.ContainingOneof().Fields() + for i := 0; i < fds.Len(); i++ { + if n == int32(fds.Get(i).Number()) { + return int8(i) + } + } + return -1 +} + +func (pur protobufUnionReflection) generateUnionFields() chan *protobufFieldReflection { + out := make(chan *protobufFieldReflection) + go func() { + defer close(out) + fds := pur.descriptor.ContainingOneof().Fields() + for i := 0; i < fds.Len(); i++ { + pfr := pur.parent.getFieldByName(string(fds.Get(i).Name())) + // Do not get stuck in a recursion loop + pfr.oneOfHandler = Null + if pfr.exclusionPolicy(pfr) { + continue + } + out <- pfr + } + }() + + return out +} + +func (pur protobufUnionReflection) getArrowFields() []arrow.Field { + var fields []arrow.Field + + for pfr := range pur.generateUnionFields() { + fields = append(fields, pfr.arrowField()) + } + + return fields +} + +func (pur protobufUnionReflection) getDataType() arrow.DataType { + fds := pur.getArrowFields() + typeCodes := make([]arrow.UnionTypeCode, len(fds)) + for i := 0; i < len(fds); i++ { + typeCodes[i] = arrow.UnionTypeCode(i) + } + return arrow.DenseUnionOf(fds, typeCodes) +} + +type protobufDictReflection struct { + protobufFieldReflection +} + +func (pfr *protobufFieldReflection) asDictionary() protobufDictReflection { + return protobufDictReflection{*pfr} +} + +func (pdr protobufDictReflection) getDataType() arrow.DataType { + return &arrow.DictionaryType{ + IndexType: arrow.PrimitiveTypes.Int32, + ValueType: arrow.BinaryTypes.String, + Ordered: false, + } +} + +type protobufMapReflection struct { + protobufFieldReflection +} + +func (pfr *protobufFieldReflection) asMap() protobufMapReflection { + return protobufMapReflection{*pfr} +} + +func (pmr 
protobufMapReflection) getDataType() arrow.DataType { + for kvp := range pmr.generateKeyValuePairs() { + return kvp.getDataType() + } + return protobufMapKeyValuePairReflection{ + k: protobufFieldReflection{ + parent: pmr.parent, + descriptor: pmr.descriptor.MapKey(), + schemaOptions: pmr.schemaOptions, + }, + v: protobufFieldReflection{ + parent: pmr.parent, + descriptor: pmr.descriptor.MapValue(), + schemaOptions: pmr.schemaOptions, + }, + }.getDataType() +} + +type protobufMapKeyValuePairReflection struct { + k protobufFieldReflection + v protobufFieldReflection +} + +func (pmr protobufMapKeyValuePairReflection) getDataType() arrow.DataType { + return arrow.MapOf(pmr.k.getDataType(), pmr.v.getDataType()) +} + +func (pmr protobufMapReflection) generateKeyValuePairs() chan protobufMapKeyValuePairReflection { + out := make(chan protobufMapKeyValuePairReflection) + + go func() { + defer close(out) + for _, k := range pmr.rValue.MapKeys() { + kvp := protobufMapKeyValuePairReflection{ + k: protobufFieldReflection{ + parent: pmr.parent, + descriptor: pmr.descriptor.MapKey(), + prValue: getMapKey(k), + rValue: k, + schemaOptions: pmr.schemaOptions, + }, + v: protobufFieldReflection{ + parent: pmr.parent, + descriptor: pmr.descriptor.MapValue(), + prValue: pmr.prValue.Map().Get(protoreflect.MapKey(getMapKey(k))), + rValue: pmr.rValue.MapIndex(k), + schemaOptions: pmr.schemaOptions, + }, + } + out <- kvp + } + }() + + return out +} + +func getMapKey(v reflect.Value) protoreflect.Value { + switch v.Kind() { + case reflect.String: + return protoreflect.ValueOf(v.String()) + case reflect.Int32, reflect.Int64: + return protoreflect.ValueOf(v.Int()) + case reflect.Bool: + return protoreflect.ValueOf(v.Bool()) + case reflect.Uint32, reflect.Uint64: + return protoreflect.ValueOf(v.Uint()) + default: + panic("Unmapped protoreflect map key type") + } +} + +func (psr ProtobufStructReflection) generateStructFields() chan *protobufFieldReflection { + out := make(chan 
*protobufFieldReflection) + + go func() { + defer close(out) + fds := psr.descriptor.Fields() + for i := 0; i < fds.Len(); i++ { + pfr := psr.getFieldByName(string(fds.Get(i).Name())) + if psr.exclusionPolicy(pfr) { + continue + } + if pfr.arrowType() == arrow.DENSE_UNION { + if pfr.descriptor.Number() != pfr.descriptor.ContainingOneof().Fields().Get(0).Number() { + continue + } + } + out <- pfr + } + }() + + return out +} + +func (psr ProtobufStructReflection) generateFields() chan *protobufFieldReflection { + out := make(chan *protobufFieldReflection) + + go func() { + defer close(out) + fds := psr.descriptor.Fields() + for i := 0; i < fds.Len(); i++ { + pfr := psr.getFieldByName(string(fds.Get(i).Name())) + if psr.exclusionPolicy(pfr) { + continue + } + if pfr.arrowType() == arrow.DENSE_UNION { + if pfr.descriptor.Number() != pfr.descriptor.ContainingOneof().Fields().Get(0).Number() { + continue + } + } + out <- pfr + } + }() + + return out +} + +func (pfr *protobufFieldReflection) asStruct() ProtobufStructReflection { + psr := ProtobufStructReflection{ + descriptor: pfr.descriptor.Message(), + rValue: pfr.rValue, + schemaOptions: pfr.schemaOptions, + } + if pfr.prValue.IsValid() { + psr.message = pfr.prValue.Message() + } + psr = psr.unmarshallAny() + return psr +} + +func (psr ProtobufStructReflection) getDataType() arrow.DataType { + return arrow.StructOf(psr.getArrowFields()...) 
+} + +func (psr ProtobufStructReflection) getFieldByName(n string) *protobufFieldReflection { + fd := psr.descriptor.Fields().ByTextName(xstrings.ToSnakeCase(n)) + fv := psr.rValue + if fv.IsValid() { + if !fv.IsZero() { + for fv.Kind() == reflect.Ptr || fv.Kind() == reflect.Interface { + fv = fv.Elem() + } + if fd.ContainingOneof() != nil { + n = string(fd.ContainingOneof().Name()) + } + fv = fv.FieldByName(xstrings.ToCamelCase(n)) + for fv.Kind() == reflect.Ptr { + fv = fv.Elem() + } + } + } + pfr := protobufFieldReflection{ + parent: &psr, + descriptor: fd, + rValue: fv, + schemaOptions: psr.schemaOptions, + } + if psr.message != nil { + pfr.prValue = psr.message.Get(fd) + } + return &pfr +} + +type protobufFieldReflection struct { + parent *ProtobufStructReflection + descriptor protoreflect.FieldDescriptor + prValue protoreflect.Value + rValue reflect.Value + schemaOptions +} + +func (pfr *protobufFieldReflection) isNull() bool { + for pfr.rValue.Kind() == reflect.Ptr { + if pfr.rValue.IsNil() { + return true + } + pfr.rValue = pfr.rValue.Elem() + } + + if !pfr.rValue.IsValid() && !pfr.prValue.IsValid() { + return true + } + return false +} + +func (pfr *protobufFieldReflection) arrowField() arrow.Field { + return arrow.Field{ + Name: pfr.name(), + Type: pfr.getDataType(), + Nullable: true, + } +} + +func (pfr *protobufFieldReflection) protoreflectValue() protoreflect.Value { + return pfr.prValue +} + +func (pfr *protobufFieldReflection) reflectValue() reflect.Value { + return pfr.rValue +} + +func (pfr *protobufFieldReflection) getDescriptor() protoreflect.FieldDescriptor { + return pfr.descriptor +} + +func (pfr *protobufFieldReflection) name() string { + if pfr.isOneOf() && pfr.schemaOptions.oneOfHandler != Null { + return pfr.fieldNameFormatter(string(pfr.descriptor.ContainingOneof().Name())) + } + return pfr.fieldNameFormatter(string(pfr.descriptor.Name())) +} + +func (pfr *protobufFieldReflection) arrowType() arrow.Type { + if pfr.isOneOf() && 
pfr.schemaOptions.oneOfHandler == DenseUnion { + return arrow.DENSE_UNION + } + if pfr.isEnum() { + return arrow.DICTIONARY + } + if pfr.isStruct() { + return arrow.STRUCT + } + if pfr.isMap() { + return arrow.MAP + } + if pfr.isList() { + return arrow.LIST + } + return arrow.NULL +} + +func (pfr *protobufFieldReflection) isOneOf() bool { + return pfr.descriptor.ContainingOneof() != nil +} + +func (pfr *protobufFieldReflection) isEnum() bool { + return pfr.descriptor.Kind() == protoreflect.EnumKind +} + +func (pfr *protobufFieldReflection) isStruct() bool { + return pfr.descriptor.Kind() == protoreflect.MessageKind && !pfr.descriptor.IsMap() && pfr.rValue.Kind() != reflect.Slice +} + +func (pfr *protobufFieldReflection) isMap() bool { + return pfr.descriptor.Kind() == protoreflect.MessageKind && pfr.descriptor.IsMap() +} + +func (pfr *protobufFieldReflection) isList() bool { + return pfr.descriptor.IsList() && pfr.rValue.Kind() == reflect.Slice +} + +func (plr protobufListReflection) generateListItems() chan protobufFieldReflection { + out := make(chan protobufFieldReflection) + + go func() { + defer close(out) + for i := 0; i < plr.prValue.List().Len(); i++ { + out <- protobufFieldReflection{ + descriptor: plr.descriptor, + prValue: plr.prValue.List().Get(i), + rValue: plr.rValue.Index(i), + schemaOptions: plr.schemaOptions, + } + } + }() + + return out +} + +func (pfr *protobufFieldReflection) getDataType() arrow.DataType { + var dt arrow.DataType + + typeMap := map[protoreflect.Kind]arrow.DataType{ + //Numeric + protoreflect.Int32Kind: arrow.PrimitiveTypes.Int32, + protoreflect.Int64Kind: arrow.PrimitiveTypes.Int64, + protoreflect.Sint32Kind: arrow.PrimitiveTypes.Int32, + protoreflect.Sint64Kind: arrow.PrimitiveTypes.Int64, + protoreflect.Uint32Kind: arrow.PrimitiveTypes.Uint32, + protoreflect.Uint64Kind: arrow.PrimitiveTypes.Uint64, + protoreflect.Fixed32Kind: arrow.PrimitiveTypes.Uint32, + protoreflect.Fixed64Kind: arrow.PrimitiveTypes.Uint64, + 
protoreflect.Sfixed32Kind: arrow.PrimitiveTypes.Int32, + protoreflect.Sfixed64Kind: arrow.PrimitiveTypes.Int64, + protoreflect.FloatKind: arrow.PrimitiveTypes.Float32, + protoreflect.DoubleKind: arrow.PrimitiveTypes.Float64, + //Binary + protoreflect.StringKind: arrow.BinaryTypes.String, + protoreflect.BytesKind: arrow.BinaryTypes.Binary, + //Fixed Width + protoreflect.BoolKind: arrow.FixedWidthTypes.Boolean, + // Special + protoreflect.EnumKind: nil, + protoreflect.MessageKind: nil, + } + dt = typeMap[pfr.descriptor.Kind()] + + switch pfr.arrowType() { + case arrow.DENSE_UNION: + dt = pfr.asUnion().getDataType() + case arrow.DICTIONARY: + dt = pfr.asDictionary().getDataType() + case arrow.LIST: + dt = pfr.asList().getDataType() + case arrow.MAP: + dt = pfr.asMap().getDataType() + case arrow.STRUCT: + dt = pfr.asStruct().getDataType() + } + + return dt +} + +func (sf SuperField) AppendValueOrNull(b array.Builder, mem memory.Allocator) { + pv := sf.protoreflectValue() + fd := sf.getDescriptor() + + if sf.isNull() { + b.AppendNull() + return + } + + switch b.Type().ID() { + case arrow.STRING: + b.(*array.StringBuilder).Append(pv.String()) + case arrow.BINARY: + b.(*array.BinaryBuilder).Append(pv.Bytes()) + case arrow.INT32: + b.(*array.Int32Builder).Append(int32(pv.Int())) + case arrow.INT64: + b.(*array.Int64Builder).Append(pv.Int()) + case arrow.FLOAT64: + b.(*array.Float64Builder).Append(pv.Float()) + case arrow.UINT32: + b.(*array.Uint32Builder).Append(uint32(pv.Uint())) + case arrow.UINT64: + b.(*array.Uint64Builder).Append(pv.Uint()) + case arrow.BOOL: + b.(*array.BooleanBuilder).Append(pv.Bool()) + case arrow.DENSE_UNION: + ub := b.(array.UnionBuilder) + pur := sf.asUnion() + if pur.whichOne() == -1 { + ub.AppendNull() + break + } + ub.Append(pur.whichOne()) + cb := ub.Child(int(pur.whichOne())) + SuperField{ + parent: sf.parent, + protobufReflection: pur.getField(), + Field: pur.arrowField(), + }.AppendValueOrNull(cb, mem) + case arrow.DICTIONARY: + db := 
b.(array.DictionaryBuilder) + err := db.AppendValueFromString(string(fd.Enum().Values().ByNumber(pv.Enum()).Name())) + if err != nil { + fmt.Println(err) + } Review Comment: Return the error instead of printing it with fmt.Println. ########## go/arrow/util/protobuf_reflect_test.go: ########## @@ -0,0 +1,247 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +package util + +import ( + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/memory" + "github.com/apache/arrow/go/v17/arrow/util/util_message" + "github.com/huandu/xstrings" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/anypb" + "strings" + "testing" +) + +func SetupTest() util_message.AllTheTypes { + msg := util_message.ExampleMessage{ + Field1: "Example", + } + + anyMsg, _ := anypb.New(&msg) + + return util_message.AllTheTypes{ + String_: "Hello", + Int32: 10, + Int64: 100, + Sint32: -10, + Sin64: -100, + Uint32: 10, + Uint64: 100, + Fixed32: 10, + Fixed64: 1000, + Sfixed32: 10, + Bool: false, + Bytes: []byte("Hello, world!"), + Double: 1.1, + Enum: util_message.AllTheTypes_OPTION_0, + Message: &msg, + Oneof: &util_message.AllTheTypes_Oneofstring{Oneofstring: "World"}, + Any: anyMsg, + //Breaks the test as the `RecordFromJson` randomises the order + //SimpleMap: map[int32]string{99: "Hello", 100: "World", 98: "How", 101: "Are", 1: "You"}, + SimpleMap: map[int32]string{99: "Hello"}, + ComplexMap: map[string]*util_message.ExampleMessage{"complex": &msg}, + SimpleList: []string{"Hello", "World"}, + ComplexList: []*util_message.ExampleMessage{&msg}, + } +} + +func TestGetSchema(t *testing.T) { + msg := SetupTest() + + got := NewSuperMessage(&msg).Schema().String() + want := `schema: + fields: 22 + - string: type=utf8, nullable + - int32: type=int32, nullable + - int64: type=int64, nullable + - sint32: type=int32, nullable + - sin64: type=int64, nullable + - uint32: type=uint32, nullable + - uint64: type=uint64, nullable + - fixed32: type=uint32, nullable + - fixed64: type=uint64, nullable + - sfixed32: type=int32, nullable + - bool: type=bool, nullable + - bytes: type=binary, nullable + - double: type=float64, nullable + - enum: type=dictionary<values=utf8, indices=int32, ordered=false>, nullable + - message: type=struct<field1: utf8>, nullable + - oneofstring: type=utf8, nullable + - 
oneofmessage: type=struct<field1: utf8>, nullable + - any: type=struct<field1: utf8>, nullable + - simple_map: type=map<int32, utf8, items_nullable>, nullable + - complex_map: type=map<utf8, struct<field1: utf8>, items_nullable>, nullable + - simple_list: type=list<item: utf8, nullable>, nullable + - complex_list: type=list<item: struct<field1: utf8>, nullable>, nullable` + + require.Equal(t, want, got, "got: %s\nwant: %s", got, want) + + got = NewSuperMessage(&msg, WithOneOfHandler(DenseUnion)).Schema().String() + want = `schema: + fields: 21 + - string: type=utf8, nullable + - int32: type=int32, nullable + - int64: type=int64, nullable + - sint32: type=int32, nullable + - sin64: type=int64, nullable + - uint32: type=uint32, nullable + - uint64: type=uint64, nullable + - fixed32: type=uint32, nullable + - fixed64: type=uint64, nullable + - sfixed32: type=int32, nullable + - bool: type=bool, nullable + - bytes: type=binary, nullable + - double: type=float64, nullable + - enum: type=dictionary<values=utf8, indices=int32, ordered=false>, nullable + - message: type=struct<field1: utf8>, nullable + - oneof: type=dense_union<oneofstring: type=utf8, nullable=0, oneofmessage: type=struct<field1: utf8>, nullable=1>, nullable + - any: type=struct<field1: utf8>, nullable + - simple_map: type=map<int32, utf8, items_nullable>, nullable + - complex_map: type=map<utf8, struct<field1: utf8>, items_nullable>, nullable + - simple_list: type=list<item: utf8, nullable>, nullable + - complex_list: type=list<item: struct<field1: utf8>, nullable>, nullable` + + require.Equal(t, want, got, "got: %s\nwant: %s", got, want) + + excludeComplex := func(pfr *protobufFieldReflection) bool { + return pfr.isMap() || pfr.isList() || pfr.isStruct() + } + + got = NewSuperMessage(&msg, WithExclusionPolicy(excludeComplex)).Schema().String() + want = `schema: + fields: 15 + - string: type=utf8, nullable + - int32: type=int32, nullable + - int64: type=int64, nullable + - sint32: type=int32, nullable + - 
sin64: type=int64, nullable + - uint32: type=uint32, nullable + - uint64: type=uint64, nullable + - fixed32: type=uint32, nullable + - fixed64: type=uint64, nullable + - sfixed32: type=int32, nullable + - bool: type=bool, nullable + - bytes: type=binary, nullable + - double: type=float64, nullable + - enum: type=dictionary<values=utf8, indices=int32, ordered=false>, nullable + - oneofstring: type=utf8, nullable` + + require.Equal(t, want, got, "got: %s\nwant: %s", got, want) + + got = NewSuperMessage( + &msg, + WithExclusionPolicy(excludeComplex), + WithFieldNameFormatter(xstrings.ToCamelCase), + ).Schema().String() + want = `schema: + fields: 15 + - String: type=utf8, nullable + - Int32: type=int32, nullable + - Int64: type=int64, nullable + - Sint32: type=int32, nullable + - Sin64: type=int64, nullable + - Uint32: type=uint32, nullable + - Uint64: type=uint64, nullable + - Fixed32: type=uint32, nullable + - Fixed64: type=uint64, nullable + - Sfixed32: type=int32, nullable + - Bool: type=bool, nullable + - Bytes: type=binary, nullable + - Double: type=float64, nullable + - Enum: type=dictionary<values=utf8, indices=int32, ordered=false>, nullable + - Oneofstring: type=utf8, nullable` + + require.Equal(t, want, got, "got: %s\nwant: %s", got, want) +} + +func TestRecordFromProtobuf(t *testing.T) { + msg := SetupTest() + + sm := NewSuperMessage(&msg, WithOneOfHandler(DenseUnion)) + schema := sm.Schema() + got := sm.Record(nil) + jsonStr := `[ + { + "string":"Hello", + "int32":10, + "int64":100, + "sint32":-10, + "sin64":-100, + "uint32":10, + "uint64":100, + "fixed32":10, + "fixed64":1000, + "sfixed32":10, + "bool":false, + "bytes":"SGVsbG8sIHdvcmxkIQ==", + "double":1.1, + "enum":"OPTION_0", + "message":{"field1":"Example"}, + "oneof": [0, "World"], + "any":{"field1":"Example"}, + "simple_map":[{"key":99,"value":"Hello"}], + "complex_map":[{"key":"complex","value":{"field1":"Example"}}], + "simple_list":["Hello","World"], + "complex_list":[{"field1":"Example"}] + } + 
]` + want, _, err := array.RecordFromJSON(memory.NewGoAllocator(), schema, strings.NewReader(jsonStr)) + + require.NoError(t, err) + require.True(t, array.RecordEqual(got, want), "got: %s\nwant: %s", got, want) + + sm = NewSuperMessage(&util_message.AllTheTypes{}) + schema = sm.Schema() + got = sm.Record(nil) + jsonStr = `[ + { + "string":"", + "int32":0, + "int64":0, + "sint32":0, + "sin64":0, + "uint32":0, + "uint64":0, + "fixed32":0, + "fixed64":0, + "sfixed32":0, + "bool":false, + "bytes":"", + "double":0, + "enum":"OPTION_0", + "message":{"field1":""}, + "oneofmessage":{"field1":""}, + "oneofstring":"", + "oneof": [0, "World"], + "any":{"type_url":"","value":""}, + "simple_map":[], + "complex_map":[], + "simple_list":[], + "complex_list":[] + } + ]` + + //jb, _ := got.MarshalJSON() + //fmt.Println(string(jb)) + Review Comment: remove? ########## go/arrow/util/protobuf_reflect_test.go: ########## @@ -0,0 +1,247 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package util + +import ( + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/memory" + "github.com/apache/arrow/go/v17/arrow/util/util_message" + "github.com/huandu/xstrings" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/anypb" + "strings" + "testing" +) + +func SetupTest() util_message.AllTheTypes { + msg := util_message.ExampleMessage{ + Field1: "Example", + } + + anyMsg, _ := anypb.New(&msg) + + return util_message.AllTheTypes{ + String_: "Hello", + Int32: 10, + Int64: 100, + Sint32: -10, + Sin64: -100, + Uint32: 10, + Uint64: 100, + Fixed32: 10, + Fixed64: 1000, + Sfixed32: 10, + Bool: false, + Bytes: []byte("Hello, world!"), + Double: 1.1, + Enum: util_message.AllTheTypes_OPTION_0, + Message: &msg, + Oneof: &util_message.AllTheTypes_Oneofstring{Oneofstring: "World"}, + Any: anyMsg, + //Breaks the test as the `RecordFromJson` randomises the order + //SimpleMap: map[int32]string{99: "Hello", 100: "World", 98: "How", 101: "Are", 1: "You"}, + SimpleMap: map[int32]string{99: "Hello"}, + ComplexMap: map[string]*util_message.ExampleMessage{"complex": &msg}, + SimpleList: []string{"Hello", "World"}, + ComplexList: []*util_message.ExampleMessage{&msg}, + } +} + +func TestGetSchema(t *testing.T) { + msg := SetupTest() + + got := NewSuperMessage(&msg).Schema().String() + want := `schema: + fields: 22 + - string: type=utf8, nullable + - int32: type=int32, nullable + - int64: type=int64, nullable + - sint32: type=int32, nullable + - sin64: type=int64, nullable + - uint32: type=uint32, nullable + - uint64: type=uint64, nullable + - fixed32: type=uint32, nullable + - fixed64: type=uint64, nullable + - sfixed32: type=int32, nullable + - bool: type=bool, nullable + - bytes: type=binary, nullable + - double: type=float64, nullable + - enum: type=dictionary<values=utf8, indices=int32, ordered=false>, nullable + - message: type=struct<field1: utf8>, nullable + - oneofstring: type=utf8, nullable + - 
oneofmessage: type=struct<field1: utf8>, nullable + - any: type=struct<field1: utf8>, nullable + - simple_map: type=map<int32, utf8, items_nullable>, nullable + - complex_map: type=map<utf8, struct<field1: utf8>, items_nullable>, nullable + - simple_list: type=list<item: utf8, nullable>, nullable + - complex_list: type=list<item: struct<field1: utf8>, nullable>, nullable` + + require.Equal(t, want, got, "got: %s\nwant: %s", got, want) + + got = NewSuperMessage(&msg, WithOneOfHandler(DenseUnion)).Schema().String() + want = `schema: + fields: 21 + - string: type=utf8, nullable + - int32: type=int32, nullable + - int64: type=int64, nullable + - sint32: type=int32, nullable + - sin64: type=int64, nullable + - uint32: type=uint32, nullable + - uint64: type=uint64, nullable + - fixed32: type=uint32, nullable + - fixed64: type=uint64, nullable + - sfixed32: type=int32, nullable + - bool: type=bool, nullable + - bytes: type=binary, nullable + - double: type=float64, nullable + - enum: type=dictionary<values=utf8, indices=int32, ordered=false>, nullable + - message: type=struct<field1: utf8>, nullable + - oneof: type=dense_union<oneofstring: type=utf8, nullable=0, oneofmessage: type=struct<field1: utf8>, nullable=1>, nullable + - any: type=struct<field1: utf8>, nullable + - simple_map: type=map<int32, utf8, items_nullable>, nullable + - complex_map: type=map<utf8, struct<field1: utf8>, items_nullable>, nullable + - simple_list: type=list<item: utf8, nullable>, nullable + - complex_list: type=list<item: struct<field1: utf8>, nullable>, nullable` + + require.Equal(t, want, got, "got: %s\nwant: %s", got, want) + + excludeComplex := func(pfr *protobufFieldReflection) bool { + return pfr.isMap() || pfr.isList() || pfr.isStruct() + } + + got = NewSuperMessage(&msg, WithExclusionPolicy(excludeComplex)).Schema().String() + want = `schema: + fields: 15 + - string: type=utf8, nullable + - int32: type=int32, nullable + - int64: type=int64, nullable + - sint32: type=int32, nullable + - 
sin64: type=int64, nullable + - uint32: type=uint32, nullable + - uint64: type=uint64, nullable + - fixed32: type=uint32, nullable + - fixed64: type=uint64, nullable + - sfixed32: type=int32, nullable + - bool: type=bool, nullable + - bytes: type=binary, nullable + - double: type=float64, nullable + - enum: type=dictionary<values=utf8, indices=int32, ordered=false>, nullable + - oneofstring: type=utf8, nullable` + + require.Equal(t, want, got, "got: %s\nwant: %s", got, want) + + got = NewSuperMessage( + &msg, + WithExclusionPolicy(excludeComplex), + WithFieldNameFormatter(xstrings.ToCamelCase), + ).Schema().String() + want = `schema: + fields: 15 + - String: type=utf8, nullable + - Int32: type=int32, nullable + - Int64: type=int64, nullable + - Sint32: type=int32, nullable + - Sin64: type=int64, nullable + - Uint32: type=uint32, nullable + - Uint64: type=uint64, nullable + - Fixed32: type=uint32, nullable + - Fixed64: type=uint64, nullable + - Sfixed32: type=int32, nullable + - Bool: type=bool, nullable + - Bytes: type=binary, nullable + - Double: type=float64, nullable + - Enum: type=dictionary<values=utf8, indices=int32, ordered=false>, nullable + - Oneofstring: type=utf8, nullable` + + require.Equal(t, want, got, "got: %s\nwant: %s", got, want) +} + +func TestRecordFromProtobuf(t *testing.T) { Review Comment: can we do a round trip test? ########## go/arrow/util/protobuf_reflect.go: ########## @@ -0,0 +1,762 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "fmt" + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/memory" + "github.com/huandu/xstrings" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/known/anypb" + "reflect" +) + +type OneOfHandler int + +const ( + // Null means do not wrap oneOfs in a union, they are treated as separate fields + Null OneOfHandler = iota + // DenseUnion maps the protobuf OneOf to an arrow.DENSE_UNION + DenseUnion +) + +type schemaOptions struct { + exclusionPolicy func(pfr *protobufFieldReflection) bool + fieldNameFormatter func(str string) string + oneOfHandler OneOfHandler +} + +// ProtobufStructReflection represents the metadata and values of a protobuf message +type ProtobufStructReflection struct { + descriptor protoreflect.MessageDescriptor + message protoreflect.Message + rValue reflect.Value + schemaOptions +} + +type protobufReflection interface { + name() string + arrowType() arrow.Type + protoreflectValue() protoreflect.Value + reflectValue() reflect.Value + getDescriptor() protoreflect.FieldDescriptor + isNull() bool + asDictionary() protobufDictReflection + asList() protobufListReflection + asMap() protobufMapReflection + asStruct() ProtobufStructReflection + asUnion() protobufUnionReflection +} + +type SuperField struct { + parent *ProtobufStructReflection + protobufReflection + arrow.Field +} + +type SuperMessage struct { + superFields []SuperField +} + +func (sm SuperMessage) 
Schema() *arrow.Schema { + var fields []arrow.Field + for _, sf := range sm.superFields { + fields = append(fields, sf.Field) + } + return arrow.NewSchema(fields, nil) +} + +func (sm SuperMessage) Record(mem memory.Allocator) arrow.Record { + if mem == nil { + mem = memory.NewGoAllocator() + } + + schema := sm.Schema() + + recordBuilder := array.NewRecordBuilder(mem, schema) + + var fieldNames []string + for i, sf := range sm.superFields { + sf.AppendValueOrNull(recordBuilder.Field(i), mem) + fieldNames = append(fieldNames, sf.protobufReflection.name()) + } + + var arrays []arrow.Array + for _, bldr := range recordBuilder.Fields() { + a := bldr.NewArray() + arrays = append(arrays, a) + } + + structArray, _ := array.NewStructArray(arrays, fieldNames) + + return array.RecordFromStructArray(structArray, schema) +} + +func NewSuperMessage(msg proto.Message, options ...option) *SuperMessage { + v := reflect.ValueOf(msg) + for v.Kind() == reflect.Ptr { + v = v.Elem() + } + includeAll := func(pfr *protobufFieldReflection) bool { + return false + } + noFormatting := func(str string) string { + return str + } + psr := &ProtobufStructReflection{ + descriptor: msg.ProtoReflect().Descriptor(), + message: msg.ProtoReflect(), + rValue: v, + schemaOptions: schemaOptions{ + exclusionPolicy: includeAll, + fieldNameFormatter: noFormatting, + oneOfHandler: Null, + }, + } + + for _, opt := range options { + opt(psr) + } + + var fields []SuperField + + for pfr := range psr.generateFields() { + fields = append(fields, SuperField{ + parent: psr, + protobufReflection: pfr, + Field: pfr.arrowField(), + }) + } + + return &SuperMessage{superFields: fields} +} + +type option func(*ProtobufStructReflection) + +// WithExclusionPolicy is an option for a ProtobufStructReflection +// WithExclusionPolicy acts as a deny filter on the fields of a protobuf message +// i.e. prevents them from being included in the schema. +// A use case for this is to exclude fields containing PII. 
+func WithExclusionPolicy(ex func(pfr *protobufFieldReflection) bool) option { + return func(psr *ProtobufStructReflection) { + psr.exclusionPolicy = ex + } +} + +// WithFieldNameFormatter is an option for a ProtobufStructReflection +// WithFieldNameFormatter enables customisation of the field names in the arrow schema +// By default, the field names are taken from the protobuf message (.proto file) +func WithFieldNameFormatter(formatter func(str string) string) option { + return func(psr *ProtobufStructReflection) { + psr.fieldNameFormatter = formatter + } +} + +// WithOneOfHandler is an option for a ProtobufStructReflection +// WithOneOfHandler enables customisation of the protobuf oneOf type in the arrow schema +// By default, the oneOfs are mapped to separate columns +func WithOneOfHandler(oneOfHandler OneOfHandler) option { + return func(psr *ProtobufStructReflection) { + psr.oneOfHandler = oneOfHandler + } +} + +func (psr ProtobufStructReflection) unmarshallAny() ProtobufStructReflection { + if psr.descriptor.FullName() == "google.protobuf.Any" && psr.rValue.IsValid() { + for psr.rValue.Type().Kind() == reflect.Ptr { + psr.rValue = reflect.Indirect(psr.rValue) + } + fieldValueAsAny, _ := psr.rValue.Interface().(anypb.Any) + msg, _ := fieldValueAsAny.UnmarshalNew() + + v := reflect.ValueOf(msg) + for v.Kind() == reflect.Ptr { + v = reflect.Indirect(v) + } + + return ProtobufStructReflection{ + descriptor: msg.ProtoReflect().Descriptor(), + message: msg.ProtoReflect(), + rValue: v, + schemaOptions: psr.schemaOptions, + } + } else { + return psr + } +} + +func (psr ProtobufStructReflection) getArrowFields() []arrow.Field { + var fields []arrow.Field + + for pfr := range psr.generateStructFields() { + fields = append(fields, arrow.Field{ + Name: pfr.name(), + Type: pfr.getDataType(), + Nullable: true, + }) + } + + return fields +} + +type protobufListReflection struct { + protobufFieldReflection +} + +func (pfr *protobufFieldReflection) asList() 
protobufListReflection { + return protobufListReflection{*pfr} +} + +func (plr protobufListReflection) getDataType() arrow.DataType { + for li := range plr.generateListItems() { + return arrow.ListOf(li.getDataType()) + } + pfr := protobufFieldReflection{ + descriptor: plr.descriptor, + schemaOptions: plr.schemaOptions, + } + return arrow.ListOf(pfr.getDataType()) +} + +type protobufUnionReflection struct { + protobufFieldReflection +} + +func (pfr *protobufFieldReflection) asUnion() protobufUnionReflection { + return protobufUnionReflection{*pfr} +} + +func (pur protobufUnionReflection) isThisOne() bool { + for pur.rValue.Kind() == reflect.Ptr || pur.rValue.Kind() == reflect.Interface { + pur.rValue = pur.rValue.Elem() + } + return pur.rValue.Field(0).String() == pur.prValue.String() +} + +func (pur protobufUnionReflection) whichOne() arrow.UnionTypeCode { + fds := pur.descriptor.ContainingOneof().Fields() + for i := 0; i < fds.Len(); i++ { + pfr := pur.parent.getFieldByName(string(fds.Get(i).Name())) + if pfr.asUnion().isThisOne() { + return pur.getUnionTypeCode(int32(pfr.descriptor.Number())) + } + } + // i.e. all null + return -1 +} + +func (pur protobufUnionReflection) getField() *protobufFieldReflection { + fds := pur.descriptor.ContainingOneof().Fields() + for i := 0; i < fds.Len(); i++ { + pfr := pur.parent.getFieldByName(string(fds.Get(i).Name())) + if pfr.asUnion().isThisOne() { + return pfr + } + } + // i.e. 
all null + return nil +} + +func (pur protobufUnionReflection) getUnionTypeCode(n int32) arrow.UnionTypeCode { + //We use the index of the field number as there is a limit on the arrow.UnionTypeCode (127) + //which a protobuf Number could realistically exceed + fds := pur.descriptor.ContainingOneof().Fields() + for i := 0; i < fds.Len(); i++ { + if n == int32(fds.Get(i).Number()) { + return int8(i) + } + } + return -1 +} + +func (pur protobufUnionReflection) generateUnionFields() chan *protobufFieldReflection { + out := make(chan *protobufFieldReflection) + go func() { + defer close(out) + fds := pur.descriptor.ContainingOneof().Fields() + for i := 0; i < fds.Len(); i++ { + pfr := pur.parent.getFieldByName(string(fds.Get(i).Name())) + // Do not get stuck in a recursion loop + pfr.oneOfHandler = Null + if pfr.exclusionPolicy(pfr) { + continue + } + out <- pfr + } + }() + + return out +} + +func (pur protobufUnionReflection) getArrowFields() []arrow.Field { + var fields []arrow.Field + + for pfr := range pur.generateUnionFields() { + fields = append(fields, pfr.arrowField()) + } + + return fields +} + +func (pur protobufUnionReflection) getDataType() arrow.DataType { + fds := pur.getArrowFields() + typeCodes := make([]arrow.UnionTypeCode, len(fds)) + for i := 0; i < len(fds); i++ { + typeCodes[i] = arrow.UnionTypeCode(i) + } + return arrow.DenseUnionOf(fds, typeCodes) +} + +type protobufDictReflection struct { + protobufFieldReflection +} + +func (pfr *protobufFieldReflection) asDictionary() protobufDictReflection { + return protobufDictReflection{*pfr} +} + +func (pdr protobufDictReflection) getDataType() arrow.DataType { + return &arrow.DictionaryType{ + IndexType: arrow.PrimitiveTypes.Int32, + ValueType: arrow.BinaryTypes.String, + Ordered: false, + } +} + +type protobufMapReflection struct { + protobufFieldReflection +} + +func (pfr *protobufFieldReflection) asMap() protobufMapReflection { + return protobufMapReflection{*pfr} +} + +func (pmr 
protobufMapReflection) getDataType() arrow.DataType { + for kvp := range pmr.generateKeyValuePairs() { + return kvp.getDataType() + } + return protobufMapKeyValuePairReflection{ + k: protobufFieldReflection{ + parent: pmr.parent, + descriptor: pmr.descriptor.MapKey(), + schemaOptions: pmr.schemaOptions, + }, + v: protobufFieldReflection{ + parent: pmr.parent, + descriptor: pmr.descriptor.MapValue(), + schemaOptions: pmr.schemaOptions, + }, + }.getDataType() +} + +type protobufMapKeyValuePairReflection struct { + k protobufFieldReflection + v protobufFieldReflection +} + +func (pmr protobufMapKeyValuePairReflection) getDataType() arrow.DataType { + return arrow.MapOf(pmr.k.getDataType(), pmr.v.getDataType()) +} + +func (pmr protobufMapReflection) generateKeyValuePairs() chan protobufMapKeyValuePairReflection { + out := make(chan protobufMapKeyValuePairReflection) + + go func() { + defer close(out) + for _, k := range pmr.rValue.MapKeys() { + kvp := protobufMapKeyValuePairReflection{ + k: protobufFieldReflection{ + parent: pmr.parent, + descriptor: pmr.descriptor.MapKey(), + prValue: getMapKey(k), + rValue: k, + schemaOptions: pmr.schemaOptions, + }, + v: protobufFieldReflection{ + parent: pmr.parent, + descriptor: pmr.descriptor.MapValue(), + prValue: pmr.prValue.Map().Get(protoreflect.MapKey(getMapKey(k))), + rValue: pmr.rValue.MapIndex(k), + schemaOptions: pmr.schemaOptions, + }, + } + out <- kvp + } + }() + + return out +} + +func getMapKey(v reflect.Value) protoreflect.Value { + switch v.Kind() { + case reflect.String: + return protoreflect.ValueOf(v.String()) + case reflect.Int32, reflect.Int64: + return protoreflect.ValueOf(v.Int()) + case reflect.Bool: + return protoreflect.ValueOf(v.Bool()) + case reflect.Uint32, reflect.Uint64: + return protoreflect.ValueOf(v.Uint()) + default: + panic("Unmapped protoreflect map key type") + } +} + +func (psr ProtobufStructReflection) generateStructFields() chan *protobufFieldReflection { + out := make(chan 
*protobufFieldReflection) + + go func() { + defer close(out) + fds := psr.descriptor.Fields() + for i := 0; i < fds.Len(); i++ { + pfr := psr.getFieldByName(string(fds.Get(i).Name())) + if psr.exclusionPolicy(pfr) { + continue + } + if pfr.arrowType() == arrow.DENSE_UNION { + if pfr.descriptor.Number() != pfr.descriptor.ContainingOneof().Fields().Get(0).Number() { + continue + } + } + out <- pfr + } + }() + + return out +} + +func (psr ProtobufStructReflection) generateFields() chan *protobufFieldReflection { + out := make(chan *protobufFieldReflection) + + go func() { + defer close(out) + fds := psr.descriptor.Fields() + for i := 0; i < fds.Len(); i++ { + pfr := psr.getFieldByName(string(fds.Get(i).Name())) + if psr.exclusionPolicy(pfr) { + continue + } + if pfr.arrowType() == arrow.DENSE_UNION { + if pfr.descriptor.Number() != pfr.descriptor.ContainingOneof().Fields().Get(0).Number() { + continue + } + } + out <- pfr + } + }() + + return out +} + +func (pfr *protobufFieldReflection) asStruct() ProtobufStructReflection { + psr := ProtobufStructReflection{ + descriptor: pfr.descriptor.Message(), + rValue: pfr.rValue, + schemaOptions: pfr.schemaOptions, + } + if pfr.prValue.IsValid() { + psr.message = pfr.prValue.Message() + } + psr = psr.unmarshallAny() + return psr +} + +func (psr ProtobufStructReflection) getDataType() arrow.DataType { + return arrow.StructOf(psr.getArrowFields()...) 
+} + +func (psr ProtobufStructReflection) getFieldByName(n string) *protobufFieldReflection { + fd := psr.descriptor.Fields().ByTextName(xstrings.ToSnakeCase(n)) + fv := psr.rValue + if fv.IsValid() { + if !fv.IsZero() { + for fv.Kind() == reflect.Ptr || fv.Kind() == reflect.Interface { + fv = fv.Elem() + } + if fd.ContainingOneof() != nil { + n = string(fd.ContainingOneof().Name()) + } + fv = fv.FieldByName(xstrings.ToCamelCase(n)) + for fv.Kind() == reflect.Ptr { + fv = fv.Elem() + } + } + } + pfr := protobufFieldReflection{ + parent: &psr, + descriptor: fd, + rValue: fv, + schemaOptions: psr.schemaOptions, + } + if psr.message != nil { + pfr.prValue = psr.message.Get(fd) + } + return &pfr +} + +type protobufFieldReflection struct { + parent *ProtobufStructReflection + descriptor protoreflect.FieldDescriptor + prValue protoreflect.Value + rValue reflect.Value + schemaOptions +} + +func (pfr *protobufFieldReflection) isNull() bool { + for pfr.rValue.Kind() == reflect.Ptr { + if pfr.rValue.IsNil() { + return true + } + pfr.rValue = pfr.rValue.Elem() + } + + if !pfr.rValue.IsValid() && !pfr.prValue.IsValid() { + return true + } + return false +} + +func (pfr *protobufFieldReflection) arrowField() arrow.Field { + return arrow.Field{ + Name: pfr.name(), + Type: pfr.getDataType(), + Nullable: true, + } +} + +func (pfr *protobufFieldReflection) protoreflectValue() protoreflect.Value { + return pfr.prValue +} + +func (pfr *protobufFieldReflection) reflectValue() reflect.Value { + return pfr.rValue +} + +func (pfr *protobufFieldReflection) getDescriptor() protoreflect.FieldDescriptor { + return pfr.descriptor +} + +func (pfr *protobufFieldReflection) name() string { + if pfr.isOneOf() && pfr.schemaOptions.oneOfHandler != Null { + return pfr.fieldNameFormatter(string(pfr.descriptor.ContainingOneof().Name())) + } + return pfr.fieldNameFormatter(string(pfr.descriptor.Name())) +} + +func (pfr *protobufFieldReflection) arrowType() arrow.Type { + if pfr.isOneOf() && 
pfr.schemaOptions.oneOfHandler == DenseUnion { + return arrow.DENSE_UNION + } + if pfr.isEnum() { + return arrow.DICTIONARY + } + if pfr.isStruct() { + return arrow.STRUCT + } + if pfr.isMap() { + return arrow.MAP + } + if pfr.isList() { + return arrow.LIST + } + return arrow.NULL +} + +func (pfr *protobufFieldReflection) isOneOf() bool { + return pfr.descriptor.ContainingOneof() != nil +} + +func (pfr *protobufFieldReflection) isEnum() bool { + return pfr.descriptor.Kind() == protoreflect.EnumKind +} + +func (pfr *protobufFieldReflection) isStruct() bool { + return pfr.descriptor.Kind() == protoreflect.MessageKind && !pfr.descriptor.IsMap() && pfr.rValue.Kind() != reflect.Slice +} + +func (pfr *protobufFieldReflection) isMap() bool { + return pfr.descriptor.Kind() == protoreflect.MessageKind && pfr.descriptor.IsMap() +} + +func (pfr *protobufFieldReflection) isList() bool { + return pfr.descriptor.IsList() && pfr.rValue.Kind() == reflect.Slice +} + +func (plr protobufListReflection) generateListItems() chan protobufFieldReflection { + out := make(chan protobufFieldReflection) + + go func() { + defer close(out) + for i := 0; i < plr.prValue.List().Len(); i++ { + out <- protobufFieldReflection{ + descriptor: plr.descriptor, + prValue: plr.prValue.List().Get(i), + rValue: plr.rValue.Index(i), + schemaOptions: plr.schemaOptions, + } + } + }() + + return out +} + +func (pfr *protobufFieldReflection) getDataType() arrow.DataType { + var dt arrow.DataType + + typeMap := map[protoreflect.Kind]arrow.DataType{ + //Numeric + protoreflect.Int32Kind: arrow.PrimitiveTypes.Int32, + protoreflect.Int64Kind: arrow.PrimitiveTypes.Int64, + protoreflect.Sint32Kind: arrow.PrimitiveTypes.Int32, + protoreflect.Sint64Kind: arrow.PrimitiveTypes.Int64, + protoreflect.Uint32Kind: arrow.PrimitiveTypes.Uint32, + protoreflect.Uint64Kind: arrow.PrimitiveTypes.Uint64, + protoreflect.Fixed32Kind: arrow.PrimitiveTypes.Uint32, + protoreflect.Fixed64Kind: arrow.PrimitiveTypes.Uint64, + 
protoreflect.Sfixed32Kind: arrow.PrimitiveTypes.Int32, + protoreflect.Sfixed64Kind: arrow.PrimitiveTypes.Int64, + protoreflect.FloatKind: arrow.PrimitiveTypes.Float32, + protoreflect.DoubleKind: arrow.PrimitiveTypes.Float64, + //Binary + protoreflect.StringKind: arrow.BinaryTypes.String, + protoreflect.BytesKind: arrow.BinaryTypes.Binary, + //Fixed Width + protoreflect.BoolKind: arrow.FixedWidthTypes.Boolean, + // Special + protoreflect.EnumKind: nil, + protoreflect.MessageKind: nil, + } + dt = typeMap[pfr.descriptor.Kind()] + + switch pfr.arrowType() { + case arrow.DENSE_UNION: + dt = pfr.asUnion().getDataType() + case arrow.DICTIONARY: + dt = pfr.asDictionary().getDataType() + case arrow.LIST: + dt = pfr.asList().getDataType() + case arrow.MAP: + dt = pfr.asMap().getDataType() + case arrow.STRUCT: + dt = pfr.asStruct().getDataType() + } + + return dt +} + +func (sf SuperField) AppendValueOrNull(b array.Builder, mem memory.Allocator) { + pv := sf.protoreflectValue() + fd := sf.getDescriptor() + + if sf.isNull() { + b.AppendNull() + return + } + + switch b.Type().ID() { + case arrow.STRING: + b.(*array.StringBuilder).Append(pv.String()) + case arrow.BINARY: + b.(*array.BinaryBuilder).Append(pv.Bytes()) + case arrow.INT32: + b.(*array.Int32Builder).Append(int32(pv.Int())) + case arrow.INT64: + b.(*array.Int64Builder).Append(pv.Int()) + case arrow.FLOAT64: + b.(*array.Float64Builder).Append(pv.Float()) + case arrow.UINT32: + b.(*array.Uint32Builder).Append(uint32(pv.Uint())) + case arrow.UINT64: + b.(*array.Uint64Builder).Append(pv.Uint()) + case arrow.BOOL: + b.(*array.BooleanBuilder).Append(pv.Bool()) + case arrow.DENSE_UNION: + ub := b.(array.UnionBuilder) + pur := sf.asUnion() + if pur.whichOne() == -1 { + ub.AppendNull() + break + } + ub.Append(pur.whichOne()) + cb := ub.Child(int(pur.whichOne())) + SuperField{ + parent: sf.parent, + protobufReflection: pur.getField(), + Field: pur.arrowField(), + }.AppendValueOrNull(cb, mem) + case arrow.DICTIONARY: + db := 
b.(array.DictionaryBuilder) + err := db.AppendValueFromString(string(fd.Enum().Values().ByNumber(pv.Enum()).Name())) + if err != nil { + fmt.Println(err) + } + case arrow.STRUCT: + sb := b.(*array.StructBuilder) + sb.Append(true) + for i, field := range sf.Field.Type.(*arrow.StructType).Fields() { + SuperField{ + parent: sf.parent, + protobufReflection: sf.asStruct().getFieldByName(field.Name), + Field: field, + }.AppendValueOrNull(sb.FieldBuilder(i), mem) + } Review Comment: can you extract the `SuperField{}` creation from the loop, and just use the same instance over and over, just updating the members? ########## go/arrow/util/protobuf_reflect.go: ########## @@ -0,0 +1,762 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package util + +import ( + "fmt" + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/memory" + "github.com/huandu/xstrings" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/known/anypb" + "reflect" +) + +type OneOfHandler int + +const ( + // Null means do not wrap oneOfs in a union, they are treated as separate fields + Null OneOfHandler = iota + // DenseUnion maps the protobuf OneOf to an arrow.DENSE_UNION + DenseUnion +) + +type schemaOptions struct { + exclusionPolicy func(pfr *protobufFieldReflection) bool + fieldNameFormatter func(str string) string + oneOfHandler OneOfHandler +} + +// ProtobufStructReflection represents the metadata and values of a protobuf message +type ProtobufStructReflection struct { + descriptor protoreflect.MessageDescriptor + message protoreflect.Message + rValue reflect.Value + schemaOptions +} + +type protobufReflection interface { + name() string + arrowType() arrow.Type + protoreflectValue() protoreflect.Value + reflectValue() reflect.Value + getDescriptor() protoreflect.FieldDescriptor + isNull() bool + asDictionary() protobufDictReflection + asList() protobufListReflection + asMap() protobufMapReflection + asStruct() ProtobufStructReflection + asUnion() protobufUnionReflection +} + +type SuperField struct { + parent *ProtobufStructReflection + protobufReflection + arrow.Field +} + +type SuperMessage struct { + superFields []SuperField +} + +func (sm SuperMessage) Schema() *arrow.Schema { + var fields []arrow.Field + for _, sf := range sm.superFields { + fields = append(fields, sf.Field) + } + return arrow.NewSchema(fields, nil) +} + +func (sm SuperMessage) Record(mem memory.Allocator) arrow.Record { + if mem == nil { + mem = memory.NewGoAllocator() + } + + schema := sm.Schema() + + recordBuilder := array.NewRecordBuilder(mem, schema) + + var fieldNames []string + for i, sf := 
range sm.superFields { + sf.AppendValueOrNull(recordBuilder.Field(i), mem) + fieldNames = append(fieldNames, sf.protobufReflection.name()) + } + + var arrays []arrow.Array + for _, bldr := range recordBuilder.Fields() { + a := bldr.NewArray() + arrays = append(arrays, a) + } + + structArray, _ := array.NewStructArray(arrays, fieldNames) + + return array.RecordFromStructArray(structArray, schema) +} + +func NewSuperMessage(msg proto.Message, options ...option) *SuperMessage { + v := reflect.ValueOf(msg) + for v.Kind() == reflect.Ptr { + v = v.Elem() + } + includeAll := func(pfr *protobufFieldReflection) bool { + return false + } + noFormatting := func(str string) string { + return str + } + psr := &ProtobufStructReflection{ + descriptor: msg.ProtoReflect().Descriptor(), + message: msg.ProtoReflect(), + rValue: v, + schemaOptions: schemaOptions{ + exclusionPolicy: includeAll, + fieldNameFormatter: noFormatting, + oneOfHandler: Null, + }, + } + + for _, opt := range options { + opt(psr) + } + + var fields []SuperField + + for pfr := range psr.generateFields() { + fields = append(fields, SuperField{ + parent: psr, + protobufReflection: pfr, + Field: pfr.arrowField(), + }) + } + + return &SuperMessage{superFields: fields} +} + +type option func(*ProtobufStructReflection) + +// WithExclusionPolicy is an option for a ProtobufStructReflection +// WithExclusionPolicy acts as a deny filter on the fields of a protobuf message +// i.e. prevents them from being included in the schema. +// A use case for this is to exclude fields containing PII. 
+func WithExclusionPolicy(ex func(pfr *protobufFieldReflection) bool) option { + return func(psr *ProtobufStructReflection) { + psr.exclusionPolicy = ex + } +} + +// WithFieldNameFormatter is an option for a ProtobufStructReflection +// WithFieldNameFormatter enables customisation of the field names in the arrow schema +// By default, the field names are taken from the protobuf message (.proto file) +func WithFieldNameFormatter(formatter func(str string) string) option { + return func(psr *ProtobufStructReflection) { + psr.fieldNameFormatter = formatter + } +} + +// WithOneOfHandler is an option for a ProtobufStructReflection +// WithOneOfHandler enables customisation of the protobuf oneOf type in the arrow schema +// By default, the oneOfs are mapped to separate columns +func WithOneOfHandler(oneOfHandler OneOfHandler) option { + return func(psr *ProtobufStructReflection) { + psr.oneOfHandler = oneOfHandler + } +} + +func (psr ProtobufStructReflection) unmarshallAny() ProtobufStructReflection { + if psr.descriptor.FullName() == "google.protobuf.Any" && psr.rValue.IsValid() { + for psr.rValue.Type().Kind() == reflect.Ptr { + psr.rValue = reflect.Indirect(psr.rValue) + } + fieldValueAsAny, _ := psr.rValue.Interface().(anypb.Any) + msg, _ := fieldValueAsAny.UnmarshalNew() + + v := reflect.ValueOf(msg) + for v.Kind() == reflect.Ptr { + v = reflect.Indirect(v) + } + + return ProtobufStructReflection{ + descriptor: msg.ProtoReflect().Descriptor(), + message: msg.ProtoReflect(), + rValue: v, + schemaOptions: psr.schemaOptions, + } + } else { + return psr + } +} + +func (psr ProtobufStructReflection) getArrowFields() []arrow.Field { + var fields []arrow.Field + + for pfr := range psr.generateStructFields() { + fields = append(fields, arrow.Field{ + Name: pfr.name(), + Type: pfr.getDataType(), + Nullable: true, + }) + } + + return fields +} + +type protobufListReflection struct { + protobufFieldReflection +} + +func (pfr *protobufFieldReflection) asList() 
protobufListReflection { + return protobufListReflection{*pfr} +} + +func (plr protobufListReflection) getDataType() arrow.DataType { + for li := range plr.generateListItems() { + return arrow.ListOf(li.getDataType()) + } + pfr := protobufFieldReflection{ + descriptor: plr.descriptor, + schemaOptions: plr.schemaOptions, + } + return arrow.ListOf(pfr.getDataType()) +} + +type protobufUnionReflection struct { + protobufFieldReflection +} + +func (pfr *protobufFieldReflection) asUnion() protobufUnionReflection { + return protobufUnionReflection{*pfr} +} + +func (pur protobufUnionReflection) isThisOne() bool { + for pur.rValue.Kind() == reflect.Ptr || pur.rValue.Kind() == reflect.Interface { + pur.rValue = pur.rValue.Elem() + } + return pur.rValue.Field(0).String() == pur.prValue.String() +} + +func (pur protobufUnionReflection) whichOne() arrow.UnionTypeCode { + fds := pur.descriptor.ContainingOneof().Fields() + for i := 0; i < fds.Len(); i++ { + pfr := pur.parent.getFieldByName(string(fds.Get(i).Name())) + if pfr.asUnion().isThisOne() { + return pur.getUnionTypeCode(int32(pfr.descriptor.Number())) + } + } + // i.e. all null + return -1 +} + +func (pur protobufUnionReflection) getField() *protobufFieldReflection { + fds := pur.descriptor.ContainingOneof().Fields() + for i := 0; i < fds.Len(); i++ { + pfr := pur.parent.getFieldByName(string(fds.Get(i).Name())) + if pfr.asUnion().isThisOne() { + return pfr + } + } + // i.e. 
all null + return nil +} + +func (pur protobufUnionReflection) getUnionTypeCode(n int32) arrow.UnionTypeCode { + //We use the index of the field number as there is a limit on the arrow.UnionTypeCode (127) + //which a protobuf Number could realistically exceed + fds := pur.descriptor.ContainingOneof().Fields() + for i := 0; i < fds.Len(); i++ { + if n == int32(fds.Get(i).Number()) { + return int8(i) + } + } + return -1 +} + +func (pur protobufUnionReflection) generateUnionFields() chan *protobufFieldReflection { + out := make(chan *protobufFieldReflection) + go func() { + defer close(out) + fds := pur.descriptor.ContainingOneof().Fields() + for i := 0; i < fds.Len(); i++ { + pfr := pur.parent.getFieldByName(string(fds.Get(i).Name())) + // Do not get stuck in a recursion loop + pfr.oneOfHandler = Null + if pfr.exclusionPolicy(pfr) { + continue + } + out <- pfr + } + }() + + return out +} + +func (pur protobufUnionReflection) getArrowFields() []arrow.Field { + var fields []arrow.Field + + for pfr := range pur.generateUnionFields() { + fields = append(fields, pfr.arrowField()) + } + + return fields +} + +func (pur protobufUnionReflection) getDataType() arrow.DataType { + fds := pur.getArrowFields() + typeCodes := make([]arrow.UnionTypeCode, len(fds)) + for i := 0; i < len(fds); i++ { + typeCodes[i] = arrow.UnionTypeCode(i) + } + return arrow.DenseUnionOf(fds, typeCodes) +} + +type protobufDictReflection struct { + protobufFieldReflection +} + +func (pfr *protobufFieldReflection) asDictionary() protobufDictReflection { + return protobufDictReflection{*pfr} +} + +func (pdr protobufDictReflection) getDataType() arrow.DataType { + return &arrow.DictionaryType{ + IndexType: arrow.PrimitiveTypes.Int32, + ValueType: arrow.BinaryTypes.String, + Ordered: false, + } +} + +type protobufMapReflection struct { + protobufFieldReflection +} + +func (pfr *protobufFieldReflection) asMap() protobufMapReflection { + return protobufMapReflection{*pfr} +} + +func (pmr 
protobufMapReflection) getDataType() arrow.DataType { + for kvp := range pmr.generateKeyValuePairs() { + return kvp.getDataType() + } + return protobufMapKeyValuePairReflection{ + k: protobufFieldReflection{ + parent: pmr.parent, + descriptor: pmr.descriptor.MapKey(), + schemaOptions: pmr.schemaOptions, + }, + v: protobufFieldReflection{ + parent: pmr.parent, + descriptor: pmr.descriptor.MapValue(), + schemaOptions: pmr.schemaOptions, + }, + }.getDataType() +} + +type protobufMapKeyValuePairReflection struct { + k protobufFieldReflection + v protobufFieldReflection +} + +func (pmr protobufMapKeyValuePairReflection) getDataType() arrow.DataType { + return arrow.MapOf(pmr.k.getDataType(), pmr.v.getDataType()) +} + +func (pmr protobufMapReflection) generateKeyValuePairs() chan protobufMapKeyValuePairReflection { + out := make(chan protobufMapKeyValuePairReflection) + + go func() { + defer close(out) + for _, k := range pmr.rValue.MapKeys() { + kvp := protobufMapKeyValuePairReflection{ + k: protobufFieldReflection{ + parent: pmr.parent, + descriptor: pmr.descriptor.MapKey(), + prValue: getMapKey(k), + rValue: k, + schemaOptions: pmr.schemaOptions, + }, + v: protobufFieldReflection{ + parent: pmr.parent, + descriptor: pmr.descriptor.MapValue(), + prValue: pmr.prValue.Map().Get(protoreflect.MapKey(getMapKey(k))), + rValue: pmr.rValue.MapIndex(k), + schemaOptions: pmr.schemaOptions, + }, + } + out <- kvp + } + }() + + return out +} + +func getMapKey(v reflect.Value) protoreflect.Value { + switch v.Kind() { + case reflect.String: + return protoreflect.ValueOf(v.String()) + case reflect.Int32, reflect.Int64: + return protoreflect.ValueOf(v.Int()) + case reflect.Bool: + return protoreflect.ValueOf(v.Bool()) + case reflect.Uint32, reflect.Uint64: + return protoreflect.ValueOf(v.Uint()) + default: + panic("Unmapped protoreflect map key type") + } +} + +func (psr ProtobufStructReflection) generateStructFields() chan *protobufFieldReflection { + out := make(chan 
*protobufFieldReflection) + + go func() { + defer close(out) + fds := psr.descriptor.Fields() + for i := 0; i < fds.Len(); i++ { + pfr := psr.getFieldByName(string(fds.Get(i).Name())) + if psr.exclusionPolicy(pfr) { + continue + } + if pfr.arrowType() == arrow.DENSE_UNION { + if pfr.descriptor.Number() != pfr.descriptor.ContainingOneof().Fields().Get(0).Number() { + continue + } + } + out <- pfr + } + }() + + return out +} + +func (psr ProtobufStructReflection) generateFields() chan *protobufFieldReflection { + out := make(chan *protobufFieldReflection) + + go func() { + defer close(out) + fds := psr.descriptor.Fields() + for i := 0; i < fds.Len(); i++ { + pfr := psr.getFieldByName(string(fds.Get(i).Name())) + if psr.exclusionPolicy(pfr) { + continue + } + if pfr.arrowType() == arrow.DENSE_UNION { + if pfr.descriptor.Number() != pfr.descriptor.ContainingOneof().Fields().Get(0).Number() { + continue + } + } + out <- pfr + } + }() + + return out +} + +func (pfr *protobufFieldReflection) asStruct() ProtobufStructReflection { + psr := ProtobufStructReflection{ + descriptor: pfr.descriptor.Message(), + rValue: pfr.rValue, + schemaOptions: pfr.schemaOptions, + } + if pfr.prValue.IsValid() { + psr.message = pfr.prValue.Message() + } + psr = psr.unmarshallAny() + return psr +} + +func (psr ProtobufStructReflection) getDataType() arrow.DataType { + return arrow.StructOf(psr.getArrowFields()...) 
+} + +func (psr ProtobufStructReflection) getFieldByName(n string) *protobufFieldReflection { + fd := psr.descriptor.Fields().ByTextName(xstrings.ToSnakeCase(n)) + fv := psr.rValue + if fv.IsValid() { + if !fv.IsZero() { + for fv.Kind() == reflect.Ptr || fv.Kind() == reflect.Interface { + fv = fv.Elem() + } + if fd.ContainingOneof() != nil { + n = string(fd.ContainingOneof().Name()) + } + fv = fv.FieldByName(xstrings.ToCamelCase(n)) + for fv.Kind() == reflect.Ptr { + fv = fv.Elem() + } + } + } + pfr := protobufFieldReflection{ + parent: &psr, + descriptor: fd, + rValue: fv, + schemaOptions: psr.schemaOptions, + } + if psr.message != nil { + pfr.prValue = psr.message.Get(fd) + } + return &pfr +} + +type protobufFieldReflection struct { + parent *ProtobufStructReflection + descriptor protoreflect.FieldDescriptor + prValue protoreflect.Value + rValue reflect.Value + schemaOptions +} + +func (pfr *protobufFieldReflection) isNull() bool { + for pfr.rValue.Kind() == reflect.Ptr { + if pfr.rValue.IsNil() { + return true + } + pfr.rValue = pfr.rValue.Elem() + } + + if !pfr.rValue.IsValid() && !pfr.prValue.IsValid() { + return true + } + return false +} + +func (pfr *protobufFieldReflection) arrowField() arrow.Field { + return arrow.Field{ + Name: pfr.name(), + Type: pfr.getDataType(), + Nullable: true, + } +} + +func (pfr *protobufFieldReflection) protoreflectValue() protoreflect.Value { + return pfr.prValue +} + +func (pfr *protobufFieldReflection) reflectValue() reflect.Value { + return pfr.rValue +} + +func (pfr *protobufFieldReflection) getDescriptor() protoreflect.FieldDescriptor { + return pfr.descriptor +} + +func (pfr *protobufFieldReflection) name() string { + if pfr.isOneOf() && pfr.schemaOptions.oneOfHandler != Null { + return pfr.fieldNameFormatter(string(pfr.descriptor.ContainingOneof().Name())) + } + return pfr.fieldNameFormatter(string(pfr.descriptor.Name())) +} + +func (pfr *protobufFieldReflection) arrowType() arrow.Type { + if pfr.isOneOf() && 
pfr.schemaOptions.oneOfHandler == DenseUnion { + return arrow.DENSE_UNION + } + if pfr.isEnum() { + return arrow.DICTIONARY + } + if pfr.isStruct() { + return arrow.STRUCT + } + if pfr.isMap() { + return arrow.MAP + } + if pfr.isList() { + return arrow.LIST + } + return arrow.NULL +} + +func (pfr *protobufFieldReflection) isOneOf() bool { + return pfr.descriptor.ContainingOneof() != nil +} + +func (pfr *protobufFieldReflection) isEnum() bool { + return pfr.descriptor.Kind() == protoreflect.EnumKind +} + +func (pfr *protobufFieldReflection) isStruct() bool { + return pfr.descriptor.Kind() == protoreflect.MessageKind && !pfr.descriptor.IsMap() && pfr.rValue.Kind() != reflect.Slice +} + +func (pfr *protobufFieldReflection) isMap() bool { + return pfr.descriptor.Kind() == protoreflect.MessageKind && pfr.descriptor.IsMap() +} + +func (pfr *protobufFieldReflection) isList() bool { + return pfr.descriptor.IsList() && pfr.rValue.Kind() == reflect.Slice +} + +func (plr protobufListReflection) generateListItems() chan protobufFieldReflection { + out := make(chan protobufFieldReflection) + + go func() { + defer close(out) + for i := 0; i < plr.prValue.List().Len(); i++ { + out <- protobufFieldReflection{ + descriptor: plr.descriptor, + prValue: plr.prValue.List().Get(i), + rValue: plr.rValue.Index(i), + schemaOptions: plr.schemaOptions, + } + } + }() + + return out +} + +func (pfr *protobufFieldReflection) getDataType() arrow.DataType { + var dt arrow.DataType + + typeMap := map[protoreflect.Kind]arrow.DataType{ + //Numeric + protoreflect.Int32Kind: arrow.PrimitiveTypes.Int32, + protoreflect.Int64Kind: arrow.PrimitiveTypes.Int64, + protoreflect.Sint32Kind: arrow.PrimitiveTypes.Int32, + protoreflect.Sint64Kind: arrow.PrimitiveTypes.Int64, + protoreflect.Uint32Kind: arrow.PrimitiveTypes.Uint32, + protoreflect.Uint64Kind: arrow.PrimitiveTypes.Uint64, + protoreflect.Fixed32Kind: arrow.PrimitiveTypes.Uint32, + protoreflect.Fixed64Kind: arrow.PrimitiveTypes.Uint64, + 
protoreflect.Sfixed32Kind: arrow.PrimitiveTypes.Int32, + protoreflect.Sfixed64Kind: arrow.PrimitiveTypes.Int64, + protoreflect.FloatKind: arrow.PrimitiveTypes.Float32, + protoreflect.DoubleKind: arrow.PrimitiveTypes.Float64, + //Binary + protoreflect.StringKind: arrow.BinaryTypes.String, + protoreflect.BytesKind: arrow.BinaryTypes.Binary, + //Fixed Width + protoreflect.BoolKind: arrow.FixedWidthTypes.Boolean, + // Special + protoreflect.EnumKind: nil, + protoreflect.MessageKind: nil, + } + dt = typeMap[pfr.descriptor.Kind()] + + switch pfr.arrowType() { + case arrow.DENSE_UNION: + dt = pfr.asUnion().getDataType() + case arrow.DICTIONARY: + dt = pfr.asDictionary().getDataType() + case arrow.LIST: + dt = pfr.asList().getDataType() + case arrow.MAP: + dt = pfr.asMap().getDataType() + case arrow.STRUCT: + dt = pfr.asStruct().getDataType() + } + + return dt +} + +func (sf SuperField) AppendValueOrNull(b array.Builder, mem memory.Allocator) { + pv := sf.protoreflectValue() + fd := sf.getDescriptor() + + if sf.isNull() { + b.AppendNull() + return + } + + switch b.Type().ID() { + case arrow.STRING: + b.(*array.StringBuilder).Append(pv.String()) + case arrow.BINARY: + b.(*array.BinaryBuilder).Append(pv.Bytes()) + case arrow.INT32: + b.(*array.Int32Builder).Append(int32(pv.Int())) + case arrow.INT64: + b.(*array.Int64Builder).Append(pv.Int()) + case arrow.FLOAT64: + b.(*array.Float64Builder).Append(pv.Float()) + case arrow.UINT32: + b.(*array.Uint32Builder).Append(uint32(pv.Uint())) + case arrow.UINT64: + b.(*array.Uint64Builder).Append(pv.Uint()) + case arrow.BOOL: + b.(*array.BooleanBuilder).Append(pv.Bool()) + case arrow.DENSE_UNION: + ub := b.(array.UnionBuilder) + pur := sf.asUnion() + if pur.whichOne() == -1 { + ub.AppendNull() + break + } + ub.Append(pur.whichOne()) + cb := ub.Child(int(pur.whichOne())) + SuperField{ + parent: sf.parent, + protobufReflection: pur.getField(), + Field: pur.arrowField(), + }.AppendValueOrNull(cb, mem) + case arrow.DICTIONARY: + db := 
b.(array.DictionaryBuilder) + err := db.AppendValueFromString(string(fd.Enum().Values().ByNumber(pv.Enum()).Name())) + if err != nil { + fmt.Println(err) + } + case arrow.STRUCT: + sb := b.(*array.StructBuilder) + sb.Append(true) + for i, field := range sf.Field.Type.(*arrow.StructType).Fields() { + SuperField{ + parent: sf.parent, + protobufReflection: sf.asStruct().getFieldByName(field.Name), + Field: field, + }.AppendValueOrNull(sb.FieldBuilder(i), mem) + } + case arrow.LIST: + lb := b.(*array.ListBuilder) + l := pv.List().Len() + if l == 0 { + lb.AppendEmptyValue() + break + } + lb.ValueBuilder().Reserve(l) + lb.Append(true) + for li := range sf.asList().generateListItems() { + SuperField{ + parent: sf.parent, + protobufReflection: &li, + Field: sf.Field.Type.(*arrow.ListType).ElemField(), + }.AppendValueOrNull(lb.ValueBuilder(), mem) + } + case arrow.MAP: + mb := b.(*array.MapBuilder) + l := pv.Map().Len() + if l == 0 { + mb.AppendEmptyValue() + break + } + mb.KeyBuilder().Reserve(l) + mb.ItemBuilder().Reserve(l) + mb.Append(true) + + for kvp := range sf.asMap().generateKeyValuePairs() { + SuperField{ + parent: sf.parent, + protobufReflection: &kvp.k, + Field: sf.Field.Type.(*arrow.MapType).KeyField(), + }.AppendValueOrNull(mb.KeyBuilder(), mem) + SuperField{ + parent: sf.parent, + protobufReflection: &kvp.v, + Field: sf.Field.Type.(*arrow.MapType).ItemField(), + }.AppendValueOrNull(mb.ItemBuilder(), mem) + } + default: + fmt.Printf("No logic for type %s", b.Type().ID()) + } Review Comment: return an error instead, we shouldn't be outputting to the terminal here. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
