zeroshade commented on code in PR #40496:
URL: https://github.com/apache/arrow/pull/40496#discussion_r1530699007
##########
go/arrow/util/protobuf_reflect.go:
##########
@@ -0,0 +1,447 @@
+package util
+
+import (
+ "fmt"
+ "github.com/apache/arrow/go/v16/arrow"
+ "github.com/apache/arrow/go/v16/arrow/array"
+ "github.com/apache/arrow/go/v16/arrow/memory"
+ "github.com/huandu/xstrings"
+ "google.golang.org/protobuf/proto"
+ "google.golang.org/protobuf/reflect/protoreflect"
+ "google.golang.org/protobuf/types/known/anypb"
+ "reflect"
+)
+
+type SchemaOptions struct {
+ exclusionPolicy func(pfr ProtobufFieldReflection) bool
+ fieldNameFormatter func(str string) string
+}
+
+type ProtobufStructReflection struct {
+ descriptor protoreflect.MessageDescriptor
+ message protoreflect.Message
+ rValue reflect.Value
+ SchemaOptions
+}
+
+type Option func(*ProtobufStructReflection)
+
+func NewProtobufStructReflection(msg proto.Message, options ...Option)
*ProtobufStructReflection {
+ v := reflect.ValueOf(msg)
+ for v.Kind() == reflect.Ptr {
+ v = v.Elem()
+ }
+ includeAll := func(pfr ProtobufFieldReflection) bool {
+ return false
+ }
+ noFormatting := func(str string) string {
+ return str
+ }
+ psr := &ProtobufStructReflection{
+ descriptor: msg.ProtoReflect().Descriptor(),
+ message: msg.ProtoReflect(),
+ rValue: v,
+ SchemaOptions: SchemaOptions{
+ exclusionPolicy: includeAll,
+ fieldNameFormatter: noFormatting,
+ },
+ }
+
+ for _, opt := range options {
+ opt(psr)
+ }
+
+ return psr
+}
+
+func WithExclusionPolicy(ex func(pfr ProtobufFieldReflection) bool) Option {
+ return func(psr *ProtobufStructReflection) {
+ psr.exclusionPolicy = ex
+ }
+}
+
+func WithFieldNameFormatter(formatter func(str string) string) Option {
+ return func(psr *ProtobufStructReflection) {
+ psr.fieldNameFormatter = formatter
+ }
+}
+
+func (psr ProtobufStructReflection) unmarshallAny() ProtobufStructReflection {
+ if psr.descriptor.FullName() == "google.protobuf.Any" {
+ for psr.rValue.Type().Kind() == reflect.Ptr {
+ psr.rValue = reflect.Indirect(psr.rValue)
+ }
+ fieldValueAsAny, _ := psr.rValue.Interface().(anypb.Any)
+ msg, _ := fieldValueAsAny.UnmarshalNew()
+
+ v := reflect.ValueOf(msg)
+ for v.Kind() == reflect.Ptr {
+ v = reflect.Indirect(v)
+ }
+
+ return ProtobufStructReflection{
+ descriptor: msg.ProtoReflect().Descriptor(),
+ message: msg.ProtoReflect(),
+ rValue: v,
+ SchemaOptions: psr.SchemaOptions,
+ }
+ } else {
+ return psr
+ }
+}
+
+func (psr ProtobufStructReflection) GetArrowFields() []arrow.Field {
+ var fields []arrow.Field
+
+ for pfr := range psr.generateStructFields() {
+ fields = append(fields, arrow.Field{
+ Name:
psr.fieldNameFormatter(string(pfr.descriptor.Name())),
+ Type: pfr.getDataType(),
+ Nullable: true,
+ })
+ }
+
+ return fields
+}
+
+func (psr ProtobufStructReflection) GetSchema() *arrow.Schema {
+ return arrow.NewSchema(psr.GetArrowFields(), nil)
+}
+
+type ProtobufListReflection struct {
+ ProtobufFieldReflection
+}
+
+func (pfr ProtobufFieldReflection) AsList() ProtobufListReflection {
+ return ProtobufListReflection{pfr}
+}
+
+func (plr ProtobufListReflection) getDataType() arrow.DataType {
+ for li := range plr.generateListItems() {
+ return arrow.ListOf(li.getDataType())
+ }
+ return nil
+}
+
+func (pfr ProtobufFieldReflection) AsMap() ProtobufMapReflection {
+ return ProtobufMapReflection{pfr}
+}
+
+type ProtobufMapReflection struct {
+ ProtobufFieldReflection
+}
+
+func (pmr ProtobufMapReflection) getDataType() arrow.DataType {
+ for kvp := range pmr.generateKeyValuePairs() {
+ return kvp.getDataType()
+ }
+ return nil
+}
+
+type ProtobufMapKeyValuePairReflection struct {
+ k ProtobufFieldReflection
+ v ProtobufFieldReflection
+}
+
+func (pmr ProtobufMapKeyValuePairReflection) getDataType() arrow.DataType {
+ return arrow.MapOf(pmr.k.getDataType(), pmr.v.getDataType())
+}
+
+func (pmr ProtobufMapReflection) generateKeyValuePairs() chan
ProtobufMapKeyValuePairReflection {
+ out := make(chan ProtobufMapKeyValuePairReflection)
+
+ go func() {
+ defer close(out)
+ for _, k := range pmr.rValue.MapKeys() {
+ kvp := ProtobufMapKeyValuePairReflection{
+ k: ProtobufFieldReflection{
+ descriptor: pmr.descriptor.MapKey(),
Review Comment:
You're right that a goroutine + channel is similar to a Python generator, but I
don't think it's necessarily needed here: it shouldn't be too expensive to
build the slice/map and return it. The channel version can also leak its
goroutine when a caller stops ranging early, as `getDataType` does after the
first item. The result could be cached if that makes sense (but that might be
premature optimization).
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]