This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-go.git
The following commit(s) were added to refs/heads/main by this push:
new 7ce3c03 feat(parquet/pqarrow): read/write variant (#434)
7ce3c03 is described below
commit 7ce3c03fbc268cb67dee53676c48e9bd74be1684
Author: Matt Topol <[email protected]>
AuthorDate: Sun Jul 13 11:12:41 2025 -0400
feat(parquet/pqarrow): read/write variant (#434)
### Rationale for this change
resolves #310
### What changes are included in this PR?
Updating the `pqarrow` package to support full round trip read/write of
Variant values via `arrow/extensions/variant`
### Are these changes tested?
Yes, unit tests are added for both shredded and unshredded variants.
### Are there any user-facing changes?
Just the new features.
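In short, the round trip looks like the following sketch (condensed from the unit
tests added in this PR; error handling elided, the sample schema and JSON are
illustrative only):

```go
// Hedged sketch adapted from the new tests, not canonical docs: write a
// variant column to parquet and read it back as the same extension type.
vt := extensions.NewShreddedVariantType(arrow.StructOf(
	arrow.Field{Name: "event_type", Type: arrow.BinaryTypes.String}))

bldr := vt.NewBuilder(memory.DefaultAllocator)
defer bldr.Release()
// Non-object values like 42 and null fall back to the untyped "value" field.
bldr.UnmarshalJSON([]byte(`[{"event_type": "noop"}, 42, null]`))
arr := bldr.NewArray()
defer arr.Release()

rec := array.NewRecord(arrow.NewSchema([]arrow.Field{
	{Name: "variant", Type: arr.DataType(), Nullable: true},
}, nil), []arrow.Array{arr}, -1)

var buf bytes.Buffer
wr, _ := pqarrow.NewFileWriter(rec.Schema(), &buf, nil, pqarrow.DefaultWriterProps())
wr.Write(rec)
wr.Close()

rdr, _ := file.NewParquetReader(bytes.NewReader(buf.Bytes()))
reader, _ := pqarrow.NewFileReader(rdr, pqarrow.ArrowReadProperties{}, memory.DefaultAllocator)
tbl, _ := reader.ReadTable(context.Background())
// tbl.Column(0) now holds the variant extension array, equal to arr.
```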
---
arrow/extensions/variant.go | 82 ++++++++++++++++++++++++++-
arrow/extensions/variant_test.go | 81 +++++++++++++++++++++++++++
parquet/file/record_reader.go | 1 +
parquet/pqarrow/encode_arrow_test.go | 104 +++++++++++++++++++++++++++++++++++
parquet/pqarrow/file_reader.go | 40 +++++++++++++-
parquet/pqarrow/schema.go | 44 +++++++++++----
6 files changed, 339 insertions(+), 13 deletions(-)
diff --git a/arrow/extensions/variant.go b/arrow/extensions/variant.go
index fbef4a6..fe97f24 100644
--- a/arrow/extensions/variant.go
+++ b/arrow/extensions/variant.go
@@ -62,6 +62,83 @@ func NewDefaultVariantType() *VariantType {
return vt
}
+func createShreddedField(dt arrow.DataType) arrow.DataType {
+ switch t := dt.(type) {
+ case arrow.ListLikeType:
+ return arrow.ListOfNonNullable(arrow.StructOf(
+			arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: true},
+			arrow.Field{Name: "typed_value", Type: createShreddedField(t.Elem()), Nullable: true},
+ ))
+ case *arrow.StructType:
+ fields := make([]arrow.Field, 0, t.NumFields())
+ for i := range t.NumFields() {
+ f := t.Field(i)
+ fields = append(fields, arrow.Field{
+ Name: f.Name,
+ Type: arrow.StructOf(arrow.Field{
+ Name: "value",
+ Type: arrow.BinaryTypes.Binary,
+ Nullable: true,
+ }, arrow.Field{
+ Name: "typed_value",
+ Type: createShreddedField(f.Type),
+ Nullable: true,
+ }),
+ Nullable: false,
+ Metadata: f.Metadata,
+ })
+ }
+ return arrow.StructOf(fields...)
+ default:
+ return dt
+ }
+}
+
+// NewShreddedVariantType creates a new VariantType extension type using the provided
+// type to define a shredded schema by setting the `typed_value` field accordingly and
+// properly constructing the shredded fields for structs, lists and so on.
+//
+// For example:
+//
+// NewShreddedVariantType(arrow.StructOf(
+// arrow.Field{Name: "latitude", Type: arrow.PrimitiveTypes.Float64},
+//	    arrow.Field{Name: "longitude", Type: arrow.PrimitiveTypes.Float32}))
+//
+// Will create a variant type with the following structure:
+//
+// arrow.StructOf(
+//	    arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary, Nullable: false},
+//	    arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: true},
+//	    arrow.Field{Name: "typed_value", Type: arrow.StructOf(
+//	        arrow.Field{Name: "latitude", Type: arrow.StructOf(
+//	            arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: true},
+//	            arrow.Field{Name: "typed_value", Type: arrow.PrimitiveTypes.Float64, Nullable: true}),
+//	            Nullable: false},
+//	        arrow.Field{Name: "longitude", Type: arrow.StructOf(
+//	            arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: true},
+//	            arrow.Field{Name: "typed_value", Type: arrow.PrimitiveTypes.Float32, Nullable: true}),
+//	            Nullable: false},
+//	    ), Nullable: true})
+//
+// This is intended to be a convenient way to create a shredded variant type from a definition
+// of the fields to shred. If the provided data type is nil, it will create a default
+// variant type.
+func NewShreddedVariantType(dt arrow.DataType) *VariantType {
+ if dt == nil {
+ return NewDefaultVariantType()
+ }
+
+ vt, _ := NewVariantType(arrow.StructOf(
+		arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary, Nullable: false},
+		arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: true},
+ arrow.Field{
+ Name: "typed_value",
+ Type: createShreddedField(dt),
+ Nullable: true,
+ }))
+ return vt
+}
+
// NewVariantType creates a new variant type based on the provided storage type.
//
// The rules for a variant storage type are:
@@ -1480,8 +1557,9 @@ type shreddedObjBuilder struct {
}
func (b *shreddedObjBuilder) AppendMissing() {
- b.structBldr.Append(true)
+ b.structBldr.AppendValues([]bool{false})
for _, fieldBldr := range b.fieldBuilders {
+ fieldBldr.structBldr.Append(true)
fieldBldr.valueBldr.AppendNull()
fieldBldr.typedBldr.AppendMissing()
}
@@ -1489,7 +1567,7 @@ func (b *shreddedObjBuilder) AppendMissing() {
func (b *shreddedObjBuilder) tryTyped(v variant.Value) (residual []byte) {
if v.Type() != variant.Object {
- b.structBldr.AppendNull()
+ b.AppendMissing()
return v.Bytes()
}
diff --git a/arrow/extensions/variant_test.go b/arrow/extensions/variant_test.go
index 9a1c05f..6e539ee 100644
--- a/arrow/extensions/variant_test.go
+++ b/arrow/extensions/variant_test.go
@@ -1574,3 +1574,84 @@ func TestVariantBuilderUnmarshalJSON(t *testing.T) {
assert.Equal(t, int8(5), innerVal2.Value())
})
}
+
+func TestNewSimpleShreddedVariantType(t *testing.T) {
+ assert.True(t, arrow.TypeEqual(extensions.NewDefaultVariantType(),
+ extensions.NewShreddedVariantType(nil)))
+
+ vt := extensions.NewShreddedVariantType(arrow.PrimitiveTypes.Float32)
+ s := arrow.StructOf(
+ arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary},
+		arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: true},
+		arrow.Field{Name: "typed_value", Type: arrow.PrimitiveTypes.Float32, Nullable: true})
+
+	assert.Truef(t, arrow.TypeEqual(vt.Storage, s), "expected %s, got %s", s, vt.Storage)
+}
+
+func TestNewShreddedVariantType(t *testing.T) {
+ vt := extensions.NewShreddedVariantType(arrow.StructOf(arrow.Field{
+ Name: "event_type",
+ Type: arrow.BinaryTypes.String,
+ }, arrow.Field{
+ Name: "event_ts",
+ Type: arrow.FixedWidthTypes.Timestamp_us,
+ }))
+
+ assert.NotNil(t, vt)
+ s := arrow.StructOf(
+ arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary},
+		arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: true},
+ arrow.Field{Name: "typed_value", Type: arrow.StructOf(
+ arrow.Field{Name: "event_type", Type: arrow.StructOf(
+				arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: true},
+				arrow.Field{Name: "typed_value", Type: arrow.BinaryTypes.String, Nullable: true},
+			)},
+			arrow.Field{Name: "event_ts", Type: arrow.StructOf(
+				arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: true},
+				arrow.Field{Name: "typed_value", Type: arrow.FixedWidthTypes.Timestamp_us, Nullable: true},
+ )},
+ ), Nullable: true})
+
+	assert.Truef(t, arrow.TypeEqual(vt.Storage, s), "expected %s, got %s", s, vt.Storage)
+}
+
+func TestShreddedVariantNested(t *testing.T) {
+ vt := extensions.NewShreddedVariantType(arrow.StructOf(
+ arrow.Field{Name: "strval", Type: arrow.BinaryTypes.String},
+ arrow.Field{Name: "bool", Type: arrow.FixedWidthTypes.Boolean},
+ arrow.Field{Name: "location", Type: arrow.ListOf(arrow.StructOf(
+			arrow.Field{Name: "latitude", Type: arrow.PrimitiveTypes.Float64},
+			arrow.Field{Name: "longitude", Type: arrow.PrimitiveTypes.Float32},
+ ))}))
+
+ s := arrow.StructOf(
+ arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary},
+		arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: true},
+ arrow.Field{Name: "typed_value", Type: arrow.StructOf(
+ arrow.Field{Name: "strval", Type: arrow.StructOf(
+				arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: true},
+				arrow.Field{Name: "typed_value", Type: arrow.BinaryTypes.String, Nullable: true},
+			)},
+			arrow.Field{Name: "bool", Type: arrow.StructOf(
+				arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: true},
+				arrow.Field{Name: "typed_value", Type: arrow.FixedWidthTypes.Boolean, Nullable: true},
+			)},
+			arrow.Field{Name: "location", Type: arrow.StructOf(
+				arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: true},
+				arrow.Field{Name: "typed_value", Type: arrow.ListOfNonNullable(arrow.StructOf(
+					arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: true},
+					arrow.Field{Name: "typed_value", Type: arrow.StructOf(
+						arrow.Field{Name: "latitude", Type: arrow.StructOf(
+							arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: true},
+							arrow.Field{Name: "typed_value", Type: arrow.PrimitiveTypes.Float64, Nullable: true},
+						)},
+						arrow.Field{Name: "longitude", Type: arrow.StructOf(
+							arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: true},
+							arrow.Field{Name: "typed_value", Type: arrow.PrimitiveTypes.Float32, Nullable: true},
+						)},
+					), Nullable: true},
+				)), Nullable: true})},
+ ), Nullable: true})
+
+	assert.Truef(t, arrow.TypeEqual(vt.Storage, s), "expected %s, got %s", s, vt.Storage)
+}
diff --git a/parquet/file/record_reader.go b/parquet/file/record_reader.go
index 81ec0af..a21e066 100644
--- a/parquet/file/record_reader.go
+++ b/parquet/file/record_reader.go
@@ -555,6 +555,7 @@ func (rr *recordReader) ReadRecordData(numRecords int64) (int64, error) {
// no repetition levels, skip delimiting logic. each level
// represents null or not null entry
		recordsRead = utils.Min(rr.levelsWritten-rr.levelsPos, numRecords)
+ valuesToRead = recordsRead
// this is advanced by delimitRecords which we skipped
rr.levelsPos += recordsRead
} else {
diff --git a/parquet/pqarrow/encode_arrow_test.go b/parquet/pqarrow/encode_arrow_test.go
index 0a2edab..61bc263 100644
--- a/parquet/pqarrow/encode_arrow_test.go
+++ b/parquet/pqarrow/encode_arrow_test.go
@@ -2314,3 +2314,107 @@ func TestEmptyListDeltaBinaryPacked(t *testing.T) {
assert.True(t, schema.Equal(tbl.Schema()))
assert.EqualValues(t, 1, tbl.NumRows())
}
+
+func TestReadWriteNonShreddedVariant(t *testing.T) {
+ mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+ defer mem.AssertSize(t, 0)
+
+	bldr := extensions.NewVariantBuilder(mem, extensions.NewDefaultVariantType())
+ defer bldr.Release()
+
+ jsonData := `[
+ 42,
+ "text",
+ [1, 2, 3],
+ {"name": "Alice"},
+		[{"id": 1, "name": "Item 1"}, {"id": 2, "name": "Item 2"}],
+		{"items": [1, "two", true], "metadata": {"created": "2025-01-01"}},
+ null
+ ]`
+
+ err := bldr.UnmarshalJSON([]byte(jsonData))
+ require.NoError(t, err)
+
+ arr := bldr.NewArray()
+ defer arr.Release()
+
+ rec := array.NewRecord(arrow.NewSchema([]arrow.Field{
+ {Name: "variant", Type: arr.DataType(), Nullable: true},
+ }, nil), []arrow.Array{arr}, -1)
+
+ var buf bytes.Buffer
+ wr, err := pqarrow.NewFileWriter(rec.Schema(), &buf, nil,
+ pqarrow.DefaultWriterProps())
+ require.NoError(t, err)
+
+ require.NoError(t, wr.Write(rec))
+ rec.Release()
+ wr.Close()
+
+ rdr, err := file.NewParquetReader(bytes.NewReader(buf.Bytes()))
+ require.NoError(t, err)
+	reader, err := pqarrow.NewFileReader(rdr, pqarrow.ArrowReadProperties{}, memory.DefaultAllocator)
+ require.NoError(t, err)
+ defer rdr.Close()
+
+ tbl, err := reader.ReadTable(context.Background())
+ require.NoError(t, err)
+ defer tbl.Release()
+
+ assert.True(t, array.Equal(arr, tbl.Column(0).Data().Chunk(0)))
+}
+
+func TestReadWriteShreddedVariant(t *testing.T) {
+ vt := extensions.NewShreddedVariantType(arrow.StructOf(
+ arrow.Field{Name: "event_type", Type: arrow.BinaryTypes.String},
+		arrow.Field{Name: "event_ts", Type: arrow.FixedWidthTypes.Timestamp_us}))
+
+ mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+ defer mem.AssertSize(t, 0)
+
+ bldr := vt.NewBuilder(mem)
+ defer bldr.Release()
+
+ jsonData := `[
+		{"event_type": "noop", "event_ts": "1970-01-21 00:29:54.114937Z"},
+		42,
+		{"event_type": "text", "event_ts": "1970-01-21 00:29:54.954163Z"},
+		{"event_type": "list", "event_ts": "1970-01-21 00:29:54.240241Z"},
+		"text",
+		{"event_type": "object", "event_ts": "1970-01-21 00:29:54.146402Z"},
+ null
+ ]`
+
+ err := bldr.UnmarshalJSON([]byte(jsonData))
+ require.NoError(t, err)
+
+ arr := bldr.NewArray()
+ defer arr.Release()
+
+ rec := array.NewRecord(arrow.NewSchema([]arrow.Field{
+ {Name: "variant", Type: arr.DataType(), Nullable: true},
+ }, nil), []arrow.Array{arr}, -1)
+
+ var buf bytes.Buffer
+ wr, err := pqarrow.NewFileWriter(rec.Schema(), &buf,
+		parquet.NewWriterProperties(parquet.WithDictionaryDefault(false)),
+ pqarrow.DefaultWriterProps())
+ require.NoError(t, err)
+
+ require.NoError(t, wr.Write(rec))
+ rec.Release()
+ wr.Close()
+
+ rdr, err := file.NewParquetReader(bytes.NewReader(buf.Bytes()))
+ require.NoError(t, err)
+	reader, err := pqarrow.NewFileReader(rdr, pqarrow.ArrowReadProperties{}, memory.DefaultAllocator)
+ require.NoError(t, err)
+ defer rdr.Close()
+
+ tbl, err := reader.ReadTable(context.Background())
+ require.NoError(t, err)
+ defer tbl.Release()
+
+ assert.Truef(t, array.Equal(arr, tbl.Column(0).Data().Chunk(0)),
+ "expected: %s\ngot: %s", arr, tbl.Column(0).Data().Chunk(0))
+}
diff --git a/parquet/pqarrow/file_reader.go b/parquet/pqarrow/file_reader.go
index b064107..8fb114c 100644
--- a/parquet/pqarrow/file_reader.go
+++ b/parquet/pqarrow/file_reader.go
@@ -111,6 +111,37 @@ func (fr *FileReader) Schema() (*arrow.Schema, error) {
	return FromParquet(fr.rdr.MetaData().Schema, &fr.Props, fr.rdr.MetaData().KeyValueMetadata())
}
+type extensionReader struct {
+ colReaderImpl
+
+ fieldWithExt arrow.Field
+}
+
+func (er *extensionReader) Field() *arrow.Field {
+ return &er.fieldWithExt
+}
+
+func (er *extensionReader) BuildArray(boundedLen int64) (*arrow.Chunked, error) {
+ if er.colReaderImpl == nil {
+		return nil, errors.New("extension reader has no underlying column reader implementation")
+ }
+
+ chkd, err := er.colReaderImpl.BuildArray(boundedLen)
+ if err != nil {
+ return nil, err
+ }
+ defer chkd.Release()
+
+ extType := er.fieldWithExt.Type.(arrow.ExtensionType)
+
+ newChunks := make([]arrow.Array, len(chkd.Chunks()))
+ for i, c := range chkd.Chunks() {
+ newChunks[i] = array.NewExtensionArrayWithStorage(extType, c)
+ }
+
+ return arrow.NewChunked(extType, newChunks), nil
+}
+
type colReaderImpl interface {
LoadBatch(nrecs int64) error
BuildArray(boundedLen int64) (*arrow.Chunked, error)
@@ -517,7 +548,14 @@ func (fr *FileReader) getReader(ctx context.Context, field *SchemaField, arrowFi
switch arrowField.Type.ID() {
case arrow.EXTENSION:
- return nil, xerrors.New("extension type not implemented")
+ storageField := arrowField
+		storageField.Type = arrowField.Type.(arrow.ExtensionType).StorageType()
+ storageReader, err := fr.getReader(ctx, field, storageField)
+ if err != nil {
+ return nil, err
+ }
+
+		return &ColumnReader{&extensionReader{colReaderImpl: storageReader, fieldWithExt: arrowField}}, nil
case arrow.STRUCT:
childReaders := make([]*ColumnReader, len(field.Children))
diff --git a/parquet/pqarrow/schema.go b/parquet/pqarrow/schema.go
index 17603b9..2c0e70b 100644
--- a/parquet/pqarrow/schema.go
+++ b/parquet/pqarrow/schema.go
@@ -18,7 +18,6 @@ package pqarrow
import (
"encoding/base64"
- "errors"
"fmt"
"math"
"strconv"
@@ -243,25 +242,25 @@ func repFromNullable(isnullable bool) parquet.Repetition {
}
func variantToNode(t *extensions.VariantType, field arrow.Field, props *parquet.WriterProperties, arrProps ArrowWriterProperties) (schema.Node, error) {
-	metadataNode, err := fieldToNode("metadata", t.Metadata(), props, arrProps)
+ fields := make(schema.FieldList, 2, 3)
+ var err error
+
+ fields[0], err = fieldToNode("metadata", t.Metadata(), props, arrProps)
if err != nil {
return nil, err
}
- valueNode, err := fieldToNode("value", t.Value(), props, arrProps)
+ fields[1], err = fieldToNode("value", t.Value(), props, arrProps)
if err != nil {
return nil, err
}
- fields := schema.FieldList{metadataNode, valueNode}
-
- typedField := t.TypedValue()
- if typedField.Type != nil {
-		typedNode, err := fieldToNode("typed_value", typedField, props, arrProps)
+ if typed := t.TypedValue(); typed.Type != nil {
+		typedValue, err := fieldToNode("typed_value", typed, props, arrProps)
if err != nil {
return nil, err
}
- fields = append(fields, typedNode)
+ fields = append(fields, typedValue)
}
	return schema.NewGroupNodeLogical(field.Name, repFromNullable(field.Nullable),
@@ -868,9 +867,34 @@ func variantToSchemaField(n *schema.GroupNode, currentLevels file.LevelInfo, ctx
switch n.NumFields() {
case 2, 3:
default:
-		return errors.New("VARIANT group must have exactly 2 or 3 children")
+		return fmt.Errorf("VARIANT group must have exactly 2 or 3 children, not %d", n.NumFields())
}
+ if n.RepetitionType() == parquet.Repetitions.Repeated {
+ // list of variants
+ out.Children = make([]SchemaField, 1)
+ repeatedAncestorDef := currentLevels.IncrementRepeated()
+		if err := groupToStructField(n, currentLevels, ctx, &out.Children[0]); err != nil {
+ return err
+ }
+
+ storageType := out.Children[0].Field.Type
+ elemType, err := extensions.NewVariantType(storageType)
+ if err != nil {
+ return err
+ }
+
+ out.Children[0].Field.Type = elemType
+		out.Field = &arrow.Field{Name: n.Name(), Type: arrow.ListOfField(*out.Children[0].Field), Nullable: true,
+ Metadata: createFieldMeta(int(n.FieldID()))}
+ ctx.LinkParent(&out.Children[0], out)
+ out.LevelInfo = currentLevels
+ out.LevelInfo.RepeatedAncestorDefLevel = repeatedAncestorDef
+ return nil
+ }
+
+ currentLevels.Increment(n)
+
var err error
if err = groupToStructField(n, currentLevels, ctx, out); err != nil {
return err