zeroshade commented on code in PR #34631:
URL: https://github.com/apache/arrow/pull/34631#discussion_r1142290526
##########
go/parquet/pqarrow/schema.go:
##########
@@ -354,7 +358,15 @@ func fieldToNode(name string, field arrow.Field, props
*parquet.WriterProperties
return fieldToNode(name, arrow.Field{Name: name, Type:
dictType.ValueType, Nullable: field.Nullable, Metadata: field.Metadata},
props, arrprops)
case arrow.EXTENSION:
- return nil, xerrors.New("not implemented yet")
+ return fieldToNode(name, arrow.Field{
+ Name: name,
+ Type: field.Type.(arrow.ExtensionType).StorageType(),
+ Nullable: field.Nullable,
+ Metadata: arrow.MetadataFrom(map[string]string{
+ ExtensionTypeKeyName:
field.Type.(arrow.ExtensionType).ExtensionName(),
+ }),
+ // Metadata: ,
Review Comment:
what's this comment for?
##########
go/parquet/pqarrow/schema.go:
##########
@@ -354,7 +358,15 @@ func fieldToNode(name string, field arrow.Field, props
*parquet.WriterProperties
return fieldToNode(name, arrow.Field{Name: name, Type:
dictType.ValueType, Nullable: field.Nullable, Metadata: field.Metadata},
props, arrprops)
case arrow.EXTENSION:
- return nil, xerrors.New("not implemented yet")
+ return fieldToNode(name, arrow.Field{
+ Name: name,
+ Type: field.Type.(arrow.ExtensionType).StorageType(),
+ Nullable: field.Nullable,
+ Metadata: arrow.MetadataFrom(map[string]string{
+ ExtensionTypeKeyName:
field.Type.(arrow.ExtensionType).ExtensionName(),
+ }),
Review Comment:
You also need to add `ExtensionMetadataKeyName:
field.Type.(arrow.ExtensionType).Serialize()`
##########
go/parquet/pqarrow/uuid.go:
##########
@@ -0,0 +1,198 @@
+package pqarrow
+
+import (
+ "bytes"
+ "fmt"
+ "reflect"
+ "strings"
+
+ "github.com/goccy/go-json"
+
+ "github.com/apache/arrow/go/v12/arrow"
+ "github.com/apache/arrow/go/v12/arrow/array"
+ "github.com/google/uuid"
+)
+
+type UUIDBuilder struct {
+ *array.ExtensionBuilder
+}
+
+func NewUUIDBuilder(bldr *array.ExtensionBuilder) *UUIDBuilder {
+ b := &UUIDBuilder{
+ ExtensionBuilder: bldr,
+ }
+ return b
+}
+
+func (b *UUIDBuilder) Append(v uuid.UUID) {
+ b.ExtensionBuilder.Builder.(*array.FixedSizeBinaryBuilder).Append(v[:])
+}
+
+func (b *UUIDBuilder) AppendValues(v []uuid.UUID, valid []bool) {
+ data := make([][]byte, len(v))
+ for i, v := range v {
+ data[i] = v[:]
+ }
+
b.ExtensionBuilder.Builder.(*array.FixedSizeBinaryBuilder).AppendValues(data,
valid)
+}
+
+func (b *UUIDBuilder) UnmarshalOne(dec *json.Decoder) error {
+ t, err := dec.Token()
+ if err != nil {
+ return err
+ }
+
+ var val uuid.UUID
+ switch v := t.(type) {
+ case string:
+ data, err := uuid.Parse(v)
+ if err != nil {
+ return err
+ }
+ val = data
+ case nil:
+ b.AppendNull()
+ return nil
+ default:
+ return &json.UnmarshalTypeError{
+ Value: fmt.Sprint(t),
+ Type: reflect.TypeOf([]byte{}),
+ Offset: dec.InputOffset(),
+ Struct: fmt.Sprintf("FixedSizeBinary[%d]", 16),
+ }
+ }
+
+ if len(val) != 16 {
+ return &json.UnmarshalTypeError{
+ Value: fmt.Sprint(val),
+ Type: reflect.TypeOf([]byte{}),
+ Offset: dec.InputOffset(),
+ Struct: fmt.Sprintf("FixedSizeBinary[%d]", 16),
+ }
+ }
+ b.Append(val)
+ return nil
+}
+
+func (b *UUIDBuilder) Unmarshal(dec *json.Decoder) error {
+ for dec.More() {
+ if err := b.UnmarshalOne(dec); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (b *UUIDBuilder) UnmarshalJSON(data []byte) error {
+ dec := json.NewDecoder(bytes.NewReader(data))
+ t, err := dec.Token()
+ if err != nil {
+ return err
+ }
+
+ if delim, ok := t.(json.Delim); !ok || delim != '[' {
+ return fmt.Errorf("fixed size binary builder must unpack from
json array, found %s", delim)
+ }
+
+ return b.Unmarshal(dec)
+}
+
+// UUIDArray is a simple array which is a FixedSizeBinary(16)
+type UUIDArray struct {
+ array.ExtensionArrayBase
+}
+
+func (a UUIDArray) String() string {
+ arr := a.Storage().(*array.FixedSizeBinary)
+ o := new(strings.Builder)
+ o.WriteString("[")
+ for i := 0; i < arr.Len(); i++ {
+ if i > 0 {
+ o.WriteString(" ")
+ }
+ switch {
+ case a.IsNull(i):
+ o.WriteString("(null)")
+ default:
+ uuidStr, err := uuid.FromBytes(arr.Value(i))
+ if err != nil {
+ panic(fmt.Errorf("invalid uuid: %w", err))
+ }
+ fmt.Fprintf(o, "%q", uuidStr)
+ }
+ }
+ o.WriteString("]")
+ return o.String()
+}
+
+func (a *UUIDArray) MarshalJSON() ([]byte, error) {
+ arr := a.Storage().(*array.FixedSizeBinary)
+ vals := make([]interface{}, a.Len())
+ for i := 0; i < a.Len(); i++ {
+ if a.IsValid(i) {
+ uuidStr, err := uuid.FromBytes(arr.Value(i))
+ if err != nil {
+ panic(fmt.Errorf("invalid uuid: %w", err))
+ }
+ vals[i] = uuidStr
+ } else {
+ vals[i] = nil
+ }
+ }
+ return json.Marshal(vals)
+}
+
+func (a *UUIDArray) GetOneForMarshal(i int) interface{} {
+ arr := a.Storage().(*array.FixedSizeBinary)
+ if a.IsValid(i) {
+ uuidObj, err := uuid.FromBytes(arr.Value(i))
+ if err != nil {
+ panic(fmt.Errorf("invalid uuid: %w", err))
+ }
+ return uuidObj
+ }
+ return nil
+}
+
+// UUIDType is a simple extension type that represents a FixedSizeBinary(16)
+// to be used for representing UUIDs
+type UUIDType struct {
+ arrow.ExtensionBase
+}
+
+// NewUUIDType is a convenience function to create an instance of UuidType
+// with the correct storage type
+func NewUUIDType() *UUIDType {
+ return &UUIDType{
+ ExtensionBase: arrow.ExtensionBase{
+ Storage: &arrow.FixedSizeBinaryType{ByteWidth: 16}}}
+}
+
+// ArrayType returns TypeOf(UuidArray) for constructing uuid arrays
+func (UUIDType) ArrayType() reflect.Type { return reflect.TypeOf(UUIDArray{}) }
+
+func (UUIDType) ExtensionName() string { return "uuid" }
+
+// Serialize returns "uuid-serialized" for testing proper metadata passing
+func (UUIDType) Serialize() string { return "uuid-serialized" }
+
+// Deserialize expects storageType to be FixedSizeBinaryType{ByteWidth: 16}
and the data to be
+// "uuid-serialized" in order to correctly create a UuidType for testing
deserialize.
+func (UUIDType) Deserialize(storageType arrow.DataType, data string)
(arrow.ExtensionType, error) {
+ if string(data) != "uuid-serialized" {
+ return nil, fmt.Errorf("type identifier did not match: '%s'",
string(data))
+ }
+ if !arrow.TypeEqual(storageType, &arrow.FixedSizeBinaryType{ByteWidth:
16}) {
+ return nil, fmt.Errorf("invalid storage type for UuidType: %s",
storageType.Name())
+ }
+ return NewUUIDType(), nil
+}
+
+// UuidTypes are equal if both are named "uuid"
+func (u UUIDType) ExtensionEquals(other arrow.ExtensionType) bool {
+ return u.ExtensionName() == other.ExtensionName()
+}
+
+func (u UUIDType) NewBuilder(bldr *array.ExtensionBuilder) array.Builder {
+ return NewUUIDBuilder(bldr)
+}
Review Comment:
missing newline at the end of the file
##########
go/parquet/pqarrow/schema.go:
##########
@@ -32,6 +32,10 @@ import (
"golang.org/x/xerrors"
)
+// constants for the extension type metadata keys for the type name and
+// any extension metadata to be passed to deserialize.
+const ExtensionTypeKeyName = "ARROW:extension:name"
+
Review Comment:
You should just use `const ExtensionTypeKeyName = ipc.ExtensionTypeKeyName`,
or use `ipc.ExtensionTypeKeyName` directly, since it's already exported
there. No need to hardcode it here.
##########
go/parquet/pqarrow/schema_test.go:
##########
@@ -36,9 +36,13 @@ func TestGetOriginSchemaBase64(t *testing.T) {
origArrSc := arrow.NewSchema([]arrow.Field{
{Name: "f1", Type: arrow.BinaryTypes.String, Metadata: md},
{Name: "f2", Type: arrow.PrimitiveTypes.Int64, Metadata: md},
+ {Name: "uuid", Type: pqarrow.NewUUIDType(), Metadata: md},
}, nil)
arrSerializedSc := flight.SerializeSchema(origArrSc,
memory.DefaultAllocator)
+ if err := arrow.RegisterExtensionType(pqarrow.NewUUIDType()); err !=
nil {
Review Comment:
Don't forget to defer the unregister call: `defer
arrow.UnregisterExtensionType("uuid")`. It uses the extension name to
unregister, so you can either hardcode it or do something like:
```go
uuidType := pqarrow.NewUUIDType()
if err := arrow.RegisterExtensionType(uuidType); err != nil {
t.Fatal(err)
}
defer arrow.UnregisterExtensionType(uuidType.ExtensionName())
```
##########
go/parquet/pqarrow/uuid.go:
##########
@@ -0,0 +1,198 @@
+package pqarrow
+
+import (
+ "bytes"
+ "fmt"
+ "reflect"
+ "strings"
+
+ "github.com/goccy/go-json"
+
+ "github.com/apache/arrow/go/v12/arrow"
+ "github.com/apache/arrow/go/v12/arrow/array"
+ "github.com/google/uuid"
+)
+
+type UUIDBuilder struct {
Review Comment:
Maybe we should shift the existing UUID extension type out of the
`arrow/internal/testing/types` pkg and put it somewhere in a shared
`internal` package that both the `arrow` and `parquet` packages can use. That
way we don't have to duplicate the code.
Not saying you have to, just an idea. I'm also fine with having this
separately in here, but since it is only used for testing, I'd prefer to keep
it in a `_test` package or a `_test.go` file if you don't shift it to the
shared internal package. Basically, I don't want this to be externally
accessible, please.
##########
go/parquet/pqarrow/schema.go:
##########
@@ -948,7 +960,15 @@ func getNestedFactory(origin, inferred arrow.DataType)
func(fieldList []arrow.Fi
func applyOriginalStorageMetadata(origin arrow.Field, inferred *SchemaField)
(modified bool, err error) {
nchildren := len(inferred.Children)
switch origin.Type.ID() {
- case arrow.EXTENSION, arrow.SPARSE_UNION, arrow.DENSE_UNION:
+ case arrow.EXTENSION:
+ extType :=
arrow.GetExtensionType(origin.Type.(arrow.ExtensionType).ExtensionName())
+ if extType == nil {
+ err = xerrors.Errorf("arrow: extension type %q not
registered", origin.Type.Name())
+ return
+ }
+ inferred.Field.Type = extType
Review Comment:
Why call `GetExtensionType` with the name when you already have the extension
type here via `origin.Type`?
If the extension type wasn't registered, then when the schema was
deserialized you'd only have the storage type anyway and never get into this
spot. In addition, if the extension type is a parameterized type, then you're
losing the metadata that was serialized under `ExtensionMetadataKeyName`.
I think the better solution here would probably be to do something like:
```go
extType := origin.Type.(arrow.ExtensionType)
modified, err = applyOriginalStorageMetadata(arrow.Field{
Type: extType.StorageType(),
Metadata: arrow.NewMetadata(
[]string{ipc.ExtensionTypeKeyName,
ExtensionMetadataKeyName},
[]string{extType.ExtensionName(),
extType.Serialize()}),
}, inferred)
if err != nil {
return
}
inferred.Field.Type = extType
modified = true
```
Or something to that effect. That way it can properly handle extension types
that have children / are defined in terms of a struct/list/etc. If the
extension type wasn't registered, that's pretty much what would happen anyway,
since it would deserialize the schema as the storage type + the metadata.
What do you think? It might also be worthwhile to add a testing extension type
that is parameterized and whose storage type is one of the nested types, just to
ensure it is handled correctly. You can find examples of parametric types in
the `arrow/internal/testing/types/extension_types.go` file.
##########
go/parquet/pqarrow/schema_test.go:
##########
@@ -36,9 +36,13 @@ func TestGetOriginSchemaBase64(t *testing.T) {
origArrSc := arrow.NewSchema([]arrow.Field{
{Name: "f1", Type: arrow.BinaryTypes.String, Metadata: md},
{Name: "f2", Type: arrow.PrimitiveTypes.Int64, Metadata: md},
+ {Name: "uuid", Type: pqarrow.NewUUIDType(), Metadata: md},
Review Comment:
In addition to the schema `ToParquet`/`FromParquet` functions, please add
some tests for reading and writing a table / record batches which contain
extension types, such as a simple round trip:
```go
expectedTable := // ... table with extension type column
defer expectedTable.Release()
var buf bytes.Buffer
require.NoError(t, pqarrow.WriteTable(expectedTable, &buf, ....))
actualTable, err := pqarrow.ReadTable(context.Background(),
bytes.NewReader(buf.Bytes()), .....)
require.NoError(t, err)
defer actualTable.Release()
assert.True(t, array.TableEqual(expectedTable, actualTable), ...)
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]