zeroshade commented on code in PR #34631:
URL: https://github.com/apache/arrow/pull/34631#discussion_r1142290526


##########
go/parquet/pqarrow/schema.go:
##########
@@ -354,7 +358,15 @@ func fieldToNode(name string, field arrow.Field, props 
*parquet.WriterProperties
                return fieldToNode(name, arrow.Field{Name: name, Type: 
dictType.ValueType, Nullable: field.Nullable, Metadata: field.Metadata},
                        props, arrprops)
        case arrow.EXTENSION:
-               return nil, xerrors.New("not implemented yet")
+               return fieldToNode(name, arrow.Field{
+                       Name: name,
+                       Type: field.Type.(arrow.ExtensionType).StorageType(),
+                       Nullable: field.Nullable,
+                       Metadata: arrow.MetadataFrom(map[string]string{
+                               ExtensionTypeKeyName: 
field.Type.(arrow.ExtensionType).ExtensionName(),
+                       }),
+                       // Metadata: ,

Review Comment:
   What's this commented-out `// Metadata: ,` line for?



##########
go/parquet/pqarrow/schema.go:
##########
@@ -354,7 +358,15 @@ func fieldToNode(name string, field arrow.Field, props 
*parquet.WriterProperties
                return fieldToNode(name, arrow.Field{Name: name, Type: 
dictType.ValueType, Nullable: field.Nullable, Metadata: field.Metadata},
                        props, arrprops)
        case arrow.EXTENSION:
-               return nil, xerrors.New("not implemented yet")
+               return fieldToNode(name, arrow.Field{
+                       Name: name,
+                       Type: field.Type.(arrow.ExtensionType).StorageType(),
+                       Nullable: field.Nullable,
+                       Metadata: arrow.MetadataFrom(map[string]string{
+                               ExtensionTypeKeyName: 
field.Type.(arrow.ExtensionType).ExtensionName(),
+                       }),

Review Comment:
   You also need to add `ExtensionMetadataKeyName: field.Type.(arrow.ExtensionType).Serialize()` to this metadata map.



##########
go/parquet/pqarrow/uuid.go:
##########
@@ -0,0 +1,198 @@
+package pqarrow
+
+import (
+       "bytes"
+       "fmt"
+       "reflect"
+       "strings"
+
+       "github.com/goccy/go-json"
+
+       "github.com/apache/arrow/go/v12/arrow"
+       "github.com/apache/arrow/go/v12/arrow/array"
+       "github.com/google/uuid"
+)
+
+type UUIDBuilder struct {
+       *array.ExtensionBuilder
+}
+
+func NewUUIDBuilder(bldr *array.ExtensionBuilder) *UUIDBuilder {
+       b := &UUIDBuilder{
+               ExtensionBuilder: bldr,
+       }
+       return b
+}
+
+func (b *UUIDBuilder) Append(v uuid.UUID) {
+       b.ExtensionBuilder.Builder.(*array.FixedSizeBinaryBuilder).Append(v[:])
+}
+
+func (b *UUIDBuilder) AppendValues(v []uuid.UUID, valid []bool) {
+       data := make([][]byte, len(v))
+       for i, v := range v {
+               data[i] = v[:]
+       }
+       
b.ExtensionBuilder.Builder.(*array.FixedSizeBinaryBuilder).AppendValues(data, 
valid)
+}
+
+func (b *UUIDBuilder) UnmarshalOne(dec *json.Decoder) error {
+       t, err := dec.Token()
+       if err != nil {
+               return err
+       }
+
+       var val uuid.UUID
+       switch v := t.(type) {
+       case string:
+               data, err := uuid.Parse(v)
+               if err != nil {
+                       return err
+               }
+               val = data
+       case nil:
+               b.AppendNull()
+               return nil
+       default:
+               return &json.UnmarshalTypeError{
+                       Value:  fmt.Sprint(t),
+                       Type:   reflect.TypeOf([]byte{}),
+                       Offset: dec.InputOffset(),
+                       Struct: fmt.Sprintf("FixedSizeBinary[%d]", 16),
+               }
+       }
+
+       if len(val) != 16 {
+               return &json.UnmarshalTypeError{
+                       Value:  fmt.Sprint(val),
+                       Type:   reflect.TypeOf([]byte{}),
+                       Offset: dec.InputOffset(),
+                       Struct: fmt.Sprintf("FixedSizeBinary[%d]", 16),
+               }
+       }
+       b.Append(val)
+       return nil
+}
+
+func (b *UUIDBuilder) Unmarshal(dec *json.Decoder) error {
+       for dec.More() {
+               if err := b.UnmarshalOne(dec); err != nil {
+                       return err
+               }
+       }
+       return nil
+}
+
+func (b *UUIDBuilder) UnmarshalJSON(data []byte) error {
+       dec := json.NewDecoder(bytes.NewReader(data))
+       t, err := dec.Token()
+       if err != nil {
+               return err
+       }
+
+       if delim, ok := t.(json.Delim); !ok || delim != '[' {
+               return fmt.Errorf("fixed size binary builder must unpack from 
json array, found %s", delim)
+       }
+
+       return b.Unmarshal(dec)
+}
+
+// UUIDArray is a simple array which is a FixedSizeBinary(16)
+type UUIDArray struct {
+       array.ExtensionArrayBase
+}
+
+func (a UUIDArray) String() string {
+       arr := a.Storage().(*array.FixedSizeBinary)
+       o := new(strings.Builder)
+       o.WriteString("[")
+       for i := 0; i < arr.Len(); i++ {
+               if i > 0 {
+                       o.WriteString(" ")
+               }
+               switch {
+               case a.IsNull(i):
+                       o.WriteString("(null)")
+               default:
+                       uuidStr, err := uuid.FromBytes(arr.Value(i))
+                       if err != nil {
+                               panic(fmt.Errorf("invalid uuid: %w", err))
+                       }
+                       fmt.Fprintf(o, "%q", uuidStr)
+               }
+       }
+       o.WriteString("]")
+       return o.String()
+}
+
+func (a *UUIDArray) MarshalJSON() ([]byte, error) {
+       arr := a.Storage().(*array.FixedSizeBinary)
+       vals := make([]interface{}, a.Len())
+       for i := 0; i < a.Len(); i++ {
+               if a.IsValid(i) {
+                       uuidStr, err := uuid.FromBytes(arr.Value(i))
+                       if err != nil {
+                               panic(fmt.Errorf("invalid uuid: %w", err))
+                       }
+                       vals[i] = uuidStr
+               } else {
+                       vals[i] = nil
+               }
+       }
+       return json.Marshal(vals)
+}
+
+func (a *UUIDArray) GetOneForMarshal(i int) interface{} {
+       arr := a.Storage().(*array.FixedSizeBinary)
+       if a.IsValid(i) {
+               uuidObj, err := uuid.FromBytes(arr.Value(i))
+               if err != nil {
+                       panic(fmt.Errorf("invalid uuid: %w", err))
+               }
+               return uuidObj
+       }
+       return nil
+}
+
+// UUIDType is a simple extension type that represents a FixedSizeBinary(16)
+// to be used for representing UUIDs
+type UUIDType struct {
+       arrow.ExtensionBase
+}
+
+// NewUUIDType is a convenience function to create an instance of UuidType
+// with the correct storage type
+func NewUUIDType() *UUIDType {
+       return &UUIDType{
+               ExtensionBase: arrow.ExtensionBase{
+                       Storage: &arrow.FixedSizeBinaryType{ByteWidth: 16}}}
+}
+
+// ArrayType returns TypeOf(UuidArray) for constructing uuid arrays
+func (UUIDType) ArrayType() reflect.Type { return reflect.TypeOf(UUIDArray{}) }
+
+func (UUIDType) ExtensionName() string { return "uuid" }
+
+// Serialize returns "uuid-serialized" for testing proper metadata passing
+func (UUIDType) Serialize() string { return "uuid-serialized" }
+
+// Deserialize expects storageType to be FixedSizeBinaryType{ByteWidth: 16} 
and the data to be
+// "uuid-serialized" in order to correctly create a UuidType for testing 
deserialize.
+func (UUIDType) Deserialize(storageType arrow.DataType, data string) 
(arrow.ExtensionType, error) {
+       if string(data) != "uuid-serialized" {
+               return nil, fmt.Errorf("type identifier did not match: '%s'", 
string(data))
+       }
+       if !arrow.TypeEqual(storageType, &arrow.FixedSizeBinaryType{ByteWidth: 
16}) {
+               return nil, fmt.Errorf("invalid storage type for UuidType: %s", 
storageType.Name())
+       }
+       return NewUUIDType(), nil
+}
+
+// UuidTypes are equal if both are named "uuid"
+func (u UUIDType) ExtensionEquals(other arrow.ExtensionType) bool {
+       return u.ExtensionName() == other.ExtensionName()
+}
+
+func (u UUIDType) NewBuilder(bldr *array.ExtensionBuilder) array.Builder {
+       return NewUUIDBuilder(bldr)
+}

Review Comment:
   Missing newline at the end of the file.



##########
go/parquet/pqarrow/schema.go:
##########
@@ -32,6 +32,10 @@ import (
        "golang.org/x/xerrors"
 )
 
+// constants for the extension type metadata keys for the type name and
+// any extension metadata to be passed to deserialize.
+const  ExtensionTypeKeyName = "ARROW:extension:name"
+

Review Comment:
   You should use `const ExtensionTypeKeyName = ipc.ExtensionTypeKeyName`, or simply use `ipc.ExtensionTypeKeyName` directly, since it's already exported there. There's no need to hardcode it here.



##########
go/parquet/pqarrow/schema_test.go:
##########
@@ -36,9 +36,13 @@ func TestGetOriginSchemaBase64(t *testing.T) {
        origArrSc := arrow.NewSchema([]arrow.Field{
                {Name: "f1", Type: arrow.BinaryTypes.String, Metadata: md},
                {Name: "f2", Type: arrow.PrimitiveTypes.Int64, Metadata: md},
+               {Name: "uuid", Type: pqarrow.NewUUIDType(), Metadata: md},
        }, nil)
 
        arrSerializedSc := flight.SerializeSchema(origArrSc, 
memory.DefaultAllocator)
+       if err := arrow.RegisterExtensionType(pqarrow.NewUUIDType()); err != 
nil {

Review Comment:
   Don't forget to defer the unregister call: `defer arrow.UnregisterExtensionType("uuid")`. It uses the extension name to unregister, so you can either hardcode the name or do something like:
   
   ```go
   uuidType := pqarrow.NewUUIDType()
   if err := arrow.RegisterExtensionType(uuidType); err != nil {
       t.Fatal(err)
   }
   defer arrow.UnregisterExtensionType(uuidType.ExtensionName())
   ``` 



##########
go/parquet/pqarrow/uuid.go:
##########
@@ -0,0 +1,198 @@
+package pqarrow
+
+import (
+       "bytes"
+       "fmt"
+       "reflect"
+       "strings"
+
+       "github.com/goccy/go-json"
+
+       "github.com/apache/arrow/go/v12/arrow"
+       "github.com/apache/arrow/go/v12/arrow/array"
+       "github.com/google/uuid"
+)
+
+type UUIDBuilder struct {

Review Comment:
   Maybe we should move the existing UUID extension type out of the `arrow/internal/testing/types` package and put it somewhere in the shared `internal` package that both the `arrow` and `parquet` packages can use. That way we don't have to duplicate the code.
   
   I'm not saying you have to; it's just an idea. I'm also fine with keeping this separately in here. Since it is only used for testing, though, I'd prefer to keep it in a `_test` package or a `_test.go` file if you don't move it to the shared internal package. Basically, please don't make this externally accessible.



##########
go/parquet/pqarrow/schema.go:
##########
@@ -948,7 +960,15 @@ func getNestedFactory(origin, inferred arrow.DataType) 
func(fieldList []arrow.Fi
 func applyOriginalStorageMetadata(origin arrow.Field, inferred *SchemaField) 
(modified bool, err error) {
        nchildren := len(inferred.Children)
        switch origin.Type.ID() {
-       case arrow.EXTENSION, arrow.SPARSE_UNION, arrow.DENSE_UNION:
+       case arrow.EXTENSION:
+               extType := 
arrow.GetExtensionType(origin.Type.(arrow.ExtensionType).ExtensionName())
+               if extType == nil {
+                       err = xerrors.Errorf("arrow: extension type %q not 
registered", origin.Type.Name())
+                       return
+               }
+               inferred.Field.Type = extType

Review Comment:
   Why call `GetExtensionType` with the name when you already have the extension type here via `origin.Type`?
   
   If the extension type wasn't registered, then when the schema was deserialized you'd only have the storage type anyway, and you'd never get to this spot. In addition, if the extension type is a parameterized type, you're losing the metadata that was serialized under the `ExtensionMetadataKey`. 
   
   I think the better solution here would probably be to do something like:
   
   ```go
   extType := origin.Type.(arrow.ExtensionType)
   modified, err = applyOriginalStorageMetadata(arrow.Field{
                    Type: extType.StorageType(), 
                    Metadata: arrow.NewMetadata(
                               []string{ipc.ExtensionTypeKeyName, 
ExtensionMetadataKeyName},
                               []string{extType.ExtensionName(), 
extType.Serialize()}),
              }, inferred)
   if err != nil {
       return
   }
   
   inferred.Field.Type = extType
   modified = true
   ```
   
   Or something to that effect. That way it can properly handle extension types that have children or are defined in terms of a struct/list/etc. If the extension type wasn't registered, that's pretty much what would happen anyway, since the schema would be deserialized as the storage type plus the metadata.
   
   What do you think? It might also be worthwhile to add a parameterized testing extension type whose storage type is one of the nested types, just to ensure it is handled correctly. You can find examples of parametric types in the `arrow/internal/testing/types/extension_types.go` file.



##########
go/parquet/pqarrow/schema_test.go:
##########
@@ -36,9 +36,13 @@ func TestGetOriginSchemaBase64(t *testing.T) {
        origArrSc := arrow.NewSchema([]arrow.Field{
                {Name: "f1", Type: arrow.BinaryTypes.String, Metadata: md},
                {Name: "f2", Type: arrow.PrimitiveTypes.Int64, Metadata: md},
+               {Name: "uuid", Type: pqarrow.NewUUIDType(), Metadata: md},

Review Comment:
   In addition to testing the schema `ToParquet`/`FromParquet` functions, please add some tests for reading and writing tables / record batches that contain extension types, such as a simple round trip:
   
   ```go
   expectedTable := // ... table with extension type column
   defer expectedTable.Release()
   
   var buf bytes.Buffer
   require.NoError(t, pqarrow.WriteTable(expectedTable, &buf, ....))
   
   actualTable, err := pqarrow.ReadTable(context.Background(), 
bytes.NewReader(buf.Bytes()), .....)
   require.NoError(t, err)
   defer actualTable.Release()
   
   assert.True(t, array.TableEqual(expectedTable, actualTable), ...)
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to