This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-go.git
The following commit(s) were added to refs/heads/main by this push:
new a9704994 ci(parquet/pqarrow): integration tests for reading shredded
variants (#455)
a9704994 is described below
commit a97049945a08d36a494de637cffb100a02f4bc5a
Author: Matt Topol <[email protected]>
AuthorDate: Wed Aug 27 11:04:26 2025 -0400
ci(parquet/pqarrow): integration tests for reading shredded variants (#455)
### Rationale for this change
Testing the variant implementation here against parquet-java, using the
test cases generated in
https://github.com/apache/parquet-testing/pull/90/files. Overall, it
confirms that our implementation is generally compatible for reading
Parquet files written by parquet-java, with some caveats.
### What changes are included in this PR?
A new test suite in `parquet/pqarrow/variant_test.go` which uses the
test cases defined in parquet-testing, reads the Parquet files, and
compares the resulting variants against the expected ones.
Some issues were found that I believe lie with parquet-java and the
test cases rather than with the Go implementation, so discussion is
needed on the following:
* The parquet test files are missing the Logical Variant Type
annotation. Currently I've worked around that for testing purposes, but
not in a way that can be merged or that is sustainable. As such the
files need to be re-generated with the Variant Logical Type annotation
before these tests can be enabled.
* Several test cases cover variations on situations where the `value`
column is missing. Based on my reading of the
[spec](https://github.com/apache/parquet-format/blob/master/VariantShredding.md),
this seems to be an invalid scenario: the spec states that the
`typed_value` field may be omitted when not shredding elements as a
specific type, but says nothing about allowing omission of the `value`
field. Per that reading, the Go implementation currently errors when
this field is missing, so those test cases fail.
* Test case 43, `testPartiallyShreddedObjectMissingFieldConflict`, seems
to have a conflict between what is expected and what is in the spec. The
`b` field exists within the `value` field while also being a shredded
field; the test appears to assume the data in the `value` field would be
ignored, but the
[spec](https://github.com/apache/parquet-format/blob/master/VariantShredding.md#objects)
says that `value` **must never** contain fields represented by the
shredded fields. This needs clarification on the desired behavior and
result.
* Test case 84, `testShreddedObjectWithOptionalFieldStructs`, tests the
scenario where the shredded fields of an object are listed as
`optional` in the schema, but the spec states that they *must* be
`required`. Thus the Go implementation errors on this test, as the spec
says this is an error. Clarification is needed on whether this is a
valid test case.
* Test case 38, `testShreddedObjectMissingTypedValue`, tests the case
where the `typed_value` field is missing. This is allowed by the spec,
except that the spec states that in this scenario the `value` field
**must** be `required`. The test case uses `optional` here, causing the
Go implementation to fail. Clarification is needed.
* Test case 125, `testPartiallyShreddedObjectFieldConflict`, again tests
the case of a field existing in both `value` and the shredded column,
which the spec states is invalid and will lead to inconsistent results.
It is therefore not valid for this test case to assert a specific result
unless the spec is amended to state that the shredded field takes
precedence in this case.
* One thing that makes the tests a bit difficult is that when we
un-shred back into variants, the current variant code in some libraries
will automatically downcast to the smallest viable precision
(for example, downcasting an int32 into an int16 if it fits). This is
worked around by testing the *values* rather than the types, but is
worth mentioning, particularly in the case of decimal values.
* A couple of error test cases verify that particular types are **not**
supported, such as UINT32 or FIXED_LEN_BYTE_ARRAY(4). However, nothing
in the spec says that an implementation couldn't simply upcast a uint32
to an int64, or treat a fixed-length byte array shredded column as a
binary column. So is it meaningful to explicitly error on those cases
rather than allow them, since they are trivially convertible to valid
variant types?
---
arrow-testing | 2 +-
arrow/extensions/variant.go | 157 +++++++++++++------
arrow/extensions/variant_test.go | 24 +--
parquet-testing | 2 +-
parquet/pqarrow/encode_arrow.go | 2 +-
parquet/pqarrow/schema.go | 20 ++-
parquet/pqarrow/schema_test.go | 4 +-
parquet/pqarrow/variant_test.go | 326 +++++++++++++++++++++++++++++++++++++++
parquet/schema/logical_types.go | 3 +-
parquet/variant/builder.go | 16 +-
parquet/variant/builder_test.go | 4 +-
parquet/variant/variant.go | 5 +-
parquet/variant/variant_test.go | 4 +-
13 files changed, 474 insertions(+), 95 deletions(-)
diff --git a/arrow-testing b/arrow-testing
index d2a13712..6a7b02fa 160000
--- a/arrow-testing
+++ b/arrow-testing
@@ -1 +1 @@
-Subproject commit d2a13712303498963395318a4eb42872e66aead7
+Subproject commit 6a7b02fac93d8addbcdbb213264e58bfdc3068e4
diff --git a/arrow/extensions/variant.go b/arrow/extensions/variant.go
index fe97f247..659f571c 100644
--- a/arrow/extensions/variant.go
+++ b/arrow/extensions/variant.go
@@ -18,6 +18,7 @@ package extensions
import (
"bytes"
+ "errors"
"fmt"
"math"
"reflect"
@@ -171,21 +172,23 @@ func NewVariantType(storage arrow.DataType)
(*VariantType, error) {
return nil, fmt.Errorf("%w: missing non-nullable field
'metadata' in variant storage type %s", arrow.ErrInvalid, storage)
}
- if valueFieldIdx, ok = s.FieldIdx("value"); !ok {
- return nil, fmt.Errorf("%w: missing non-nullable field 'value'
in variant storage type %s", arrow.ErrInvalid, storage)
+ var valueOk, typedValueOk bool
+ valueFieldIdx, valueOk = s.FieldIdx("value")
+ typedValueFieldIdx, typedValueOk = s.FieldIdx("typed_value")
+
+ if !valueOk && !typedValueOk {
+ return nil, fmt.Errorf("%w: there must be at least one of
'value' or 'typed_value' fields in variant storage type %s", arrow.ErrInvalid,
storage)
}
- if s.NumFields() > 3 {
- return nil, fmt.Errorf("%w: too many fields in variant storage
type %s, expected 2 or 3", arrow.ErrInvalid, storage)
+ if s.NumFields() == 3 && (!valueOk || !typedValueOk) {
+ return nil, fmt.Errorf("%w: has 3 fields, but missing one of
'value' or 'typed_value' fields, %s", arrow.ErrInvalid, storage)
}
- if s.NumFields() == 3 {
- if typedValueFieldIdx, ok = s.FieldIdx("typed_value"); !ok {
- return nil, fmt.Errorf("%w: has 3 fields, but missing
'typed_value' field, %s", arrow.ErrInvalid, storage)
- }
+ if s.NumFields() > 3 {
+ return nil, fmt.Errorf("%w: too many fields in variant storage
type %s, expected 2 or 3", arrow.ErrInvalid, storage)
}
- mdField, valField := s.Field(metadataFieldIdx), s.Field(valueFieldIdx)
+ mdField := s.Field(metadataFieldIdx)
if mdField.Nullable {
return nil, fmt.Errorf("%w: metadata field must be non-nullable
binary type, got %s", arrow.ErrInvalid, mdField.Type)
}
@@ -196,11 +199,14 @@ func NewVariantType(storage arrow.DataType)
(*VariantType, error) {
}
}
- if !isBinary(valField.Type) || (valField.Nullable && typedValueFieldIdx
== -1) {
- return nil, fmt.Errorf("%w: value field must be non-nullable
binary type, got %s", arrow.ErrInvalid, valField.Type)
+ if valueOk {
+ valField := s.Field(valueFieldIdx)
+ if !isBinary(valField.Type) {
+ return nil, fmt.Errorf("%w: value field must be binary
type, got %s", arrow.ErrInvalid, valField.Type)
+ }
}
- if typedValueFieldIdx == -1 {
+ if !typedValueOk {
return &VariantType{
ExtensionBase: arrow.ExtensionBase{Storage:
storage},
metadataFieldIdx: metadataFieldIdx,
@@ -209,17 +215,17 @@ func NewVariantType(storage arrow.DataType)
(*VariantType, error) {
}, nil
}
- valueField := s.Field(valueFieldIdx)
- if !valueField.Nullable {
- return nil, fmt.Errorf("%w: value field must be nullable if
typed_value is present, got %s", arrow.ErrInvalid, valueField.Type)
- }
-
typedValueField := s.Field(typedValueFieldIdx)
if !typedValueField.Nullable {
return nil, fmt.Errorf("%w: typed_value field must be nullable,
got %s", arrow.ErrInvalid, typedValueField.Type)
}
- if nt, ok := typedValueField.Type.(arrow.NestedType); ok {
+ dt := typedValueField.Type
+ if dt.ID() == arrow.EXTENSION {
+ dt = dt.(arrow.ExtensionType).StorageType()
+ }
+
+ if nt, ok := dt.(arrow.NestedType); ok {
if !validNestedType(nt) {
return nil, fmt.Errorf("%w: typed_value field must be a
valid nested type, got %s", arrow.ErrInvalid, typedValueField.Type)
}
@@ -242,6 +248,9 @@ func (v *VariantType) Metadata() arrow.Field {
}
func (v *VariantType) Value() arrow.Field {
+ if v.valueFieldIdx == -1 {
+ return arrow.Field{}
+ }
return v.StorageType().(*arrow.StructType).Field(v.valueFieldIdx)
}
@@ -286,7 +295,7 @@ func validStruct(s *arrow.StructType) bool {
switch s.NumFields() {
case 1:
f := s.Field(0)
- return f.Name == "value" && !f.Nullable && isBinary(f.Type)
+ return (f.Name == "value" && isBinary(f.Type)) || f.Name ==
"typed_value"
case 2:
valField, ok := s.FieldByName("value")
if !ok || !valField.Nullable || !isBinary(valField.Type) {
@@ -365,8 +374,6 @@ func (v *VariantArray) initReader() {
vt := v.ExtensionType().(*VariantType)
st := v.Storage().(*array.Struct)
metaField := st.Field(vt.metadataFieldIdx)
- valueField := st.Field(vt.valueFieldIdx)
-
metadata, ok := metaField.(arrow.TypedArray[[]byte])
if !ok {
// we already validated that if the metadata field
isn't a binary
@@ -374,24 +381,30 @@ func (v *VariantArray) initReader() {
metadata, _ =
array.NewDictWrapper[[]byte](metaField.(*array.Dictionary))
}
- if vt.typedValueFieldIdx == -1 {
+ var value arrow.TypedArray[[]byte]
+ if vt.valueFieldIdx != -1 {
+ valueField := st.Field(vt.valueFieldIdx)
+ value = valueField.(arrow.TypedArray[[]byte])
+ }
+
+ var ivreader typedValReader
+ var err error
+ if vt.typedValueFieldIdx != -1 {
+ ivreader, err =
getReader(st.Field(vt.typedValueFieldIdx))
+ if err != nil {
+ v.rdrErr = err
+ return
+ }
+ v.rdr = &shreddedVariantReader{
+ metadata: metadata,
+ value: value,
+ typedValue: ivreader,
+ }
+ } else {
v.rdr = &basicVariantReader{
metadata: metadata,
- value: valueField.(arrow.TypedArray[[]byte]),
+ value: value,
}
- return
- }
-
- ivreader, err := getReader(st.Field(vt.typedValueFieldIdx))
- if err != nil {
- v.rdrErr = err
- return
- }
-
- v.rdr = &shreddedVariantReader{
- metadata: metadata,
- value: valueField.(arrow.TypedArray[[]byte]),
- typedValue: ivreader,
}
})
}
@@ -419,6 +432,9 @@ func (v *VariantArray) Metadata() arrow.TypedArray[[]byte] {
// value of null).
func (v *VariantArray) UntypedValues() arrow.TypedArray[[]byte] {
vt := v.ExtensionType().(*VariantType)
+ if vt.valueFieldIdx == -1 {
+ return nil
+ }
return
v.Storage().(*array.Struct).Field(vt.valueFieldIdx).(arrow.TypedArray[[]byte])
}
@@ -451,7 +467,6 @@ func (v *VariantArray) IsNull(i int) bool {
}
vt := v.ExtensionType().(*VariantType)
- valArr := v.Storage().(*array.Struct).Field(vt.valueFieldIdx)
if vt.typedValueFieldIdx != -1 {
typedArr :=
v.Storage().(*array.Struct).Field(vt.typedValueFieldIdx)
if !typedArr.IsNull(i) {
@@ -459,6 +474,7 @@ func (v *VariantArray) IsNull(i int) bool {
}
}
+ valArr := v.Storage().(*array.Struct).Field(vt.valueFieldIdx)
b := valArr.(arrow.TypedArray[[]byte]).Value(i)
return len(b) == 1 && b[0] == 0 // variant null
}
@@ -747,9 +763,20 @@ func getReader(typedArr arrow.Array) (typedValReader,
error) {
childType := child.DataType().(*arrow.StructType)
valueIdx, _ := childType.FieldIdx("value")
- valueArr :=
child.Field(valueIdx).(arrow.TypedArray[[]byte])
+ var valueArr arrow.TypedArray[[]byte]
+ if valueIdx != -1 {
+ valueArr =
child.Field(valueIdx).(arrow.TypedArray[[]byte])
+ }
+
+ typedValueIdx, exists :=
childType.FieldIdx("typed_value")
+ if !exists {
+ fieldReaders[fieldList[i].Name] =
fieldReaderPair{
+ values: valueArr,
+ typedVal: nil,
+ }
+ continue
+ }
- typedValueIdx, _ := childType.FieldIdx("typed_value")
typedRdr, err := getReader(child.Field(typedValueIdx))
if err != nil {
return nil, fmt.Errorf("error getting typed
value reader for field %s: %w", fieldList[i].Name, err)
@@ -768,13 +795,22 @@ func getReader(typedArr arrow.Array) (typedValReader,
error) {
case array.ListLike:
listValues := arr.ListValues().(*array.Struct)
elemType := listValues.DataType().(*arrow.StructType)
+
+ var valueArr arrow.TypedArray[[]byte]
+ var typedRdr typedValReader
+
valueIdx, _ := elemType.FieldIdx("value")
- valueArr :=
listValues.Field(valueIdx).(arrow.TypedArray[[]byte])
+ if valueIdx != -1 {
+ valueArr =
listValues.Field(valueIdx).(arrow.TypedArray[[]byte])
+ }
typedValueIdx, _ := elemType.FieldIdx("typed_value")
- typedRdr, err := getReader(listValues.Field(typedValueIdx))
- if err != nil {
- return nil, fmt.Errorf("error getting typed value
reader: %w", err)
+ if typedValueIdx != -1 {
+ var err error
+ typedRdr, err =
getReader(listValues.Field(typedValueIdx))
+ if err != nil {
+ return nil, fmt.Errorf("error getting typed
value reader: %w", err)
+ }
}
return &typedListReader{
@@ -796,6 +832,7 @@ func constructVariant(b *variant.Builder, meta
variant.Metadata, value []byte, t
switch v := typedVal.(type) {
case nil:
if len(value) == 0 {
+ b.AppendNull()
return nil
}
@@ -846,6 +883,9 @@ func constructVariant(b *variant.Builder, meta
variant.Metadata, value []byte, t
return b.FinishArray(arrstart, elems)
case []byte:
+ if len(value) > 0 {
+ return errors.New("invalid variant, conflicting value
and typed_value")
+ }
return b.UnsafeAppendEncoded(v)
default:
return fmt.Errorf("%w: unsupported typed value type %T for
variant", arrow.ErrInvalid, v)
@@ -876,14 +916,24 @@ func (v *typedObjReader) Value(meta variant.Metadata, i
int) (any, error) {
return nil, nil
}
+ var err error
result := make(map[string]typedPair)
for name, rdr := range v.fieldRdrs {
- typedValue, err := rdr.typedVal.Value(meta, i)
- if err != nil {
- return nil, fmt.Errorf("error reading typed value for
field %s at index %d: %w", name, i, err)
+ var typedValue any
+ if rdr.typedVal != nil {
+ typedValue, err = rdr.typedVal.Value(meta, i)
+ if err != nil {
+ return nil, fmt.Errorf("error reading typed
value for field %s at index %d: %w", name, i, err)
+ }
}
+
+ var val []byte
+ if rdr.values != nil {
+ val = rdr.values.Value(i)
+ }
+
result[name] = typedPair{
- Value: rdr.values.Value(i),
+ Value: val,
TypedValue: typedValue,
}
}
@@ -913,7 +963,11 @@ func (v *typedListReader) Value(meta variant.Metadata, i
int) (any, error) {
result := make([]typedPair, 0, end-start)
for j := start; j < end; j++ {
- val := v.valueArr.Value(int(j))
+ var val []byte
+ if v.valueArr != nil {
+ val = v.valueArr.Value(int(j))
+ }
+
typedValue, err := v.typedVal.Value(meta, int(j))
if err != nil {
return nil, fmt.Errorf("error reading typed value at
index %d: %w", j, err)
@@ -956,12 +1010,17 @@ func (v *shreddedVariantReader) Value(i int)
(variant.Value, error) {
}
b := variant.NewBuilderFromMeta(meta)
+ b.SetAllowDuplicates(true)
typed, err := v.typedValue.Value(meta, i)
if err != nil {
return variant.NullValue, fmt.Errorf("error reading typed value
at index %d: %w", i, err)
}
- if err := constructVariant(b, meta, v.value.Value(i), typed); err !=
nil {
+ var value []byte
+ if v.value != nil {
+ value = v.value.Value(i)
+ }
+ if err := constructVariant(b, meta, value, typed); err != nil {
return variant.NullValue, fmt.Errorf("error constructing
variant at index %d: %w", i, err)
}
return b.Build()
diff --git a/arrow/extensions/variant_test.go b/arrow/extensions/variant_test.go
index 6e539ee5..925d0621 100644
--- a/arrow/extensions/variant_test.go
+++ b/arrow/extensions/variant_test.go
@@ -61,21 +61,18 @@ func TestVariantExtensionType(t *testing.T) {
expectedErr string
}{
{arrow.StructOf(arrow.Field{Name: "metadata", Type:
arrow.BinaryTypes.Binary}),
- "missing non-nullable field 'value'"},
+ "there must be at least one of 'value' or 'typed_value'
fields in variant storage type"},
{arrow.StructOf(arrow.Field{Name: "value", Type:
arrow.BinaryTypes.Binary}), "missing non-nullable field 'metadata'"},
{arrow.StructOf(arrow.Field{Name: "metadata", Type:
arrow.BinaryTypes.Binary},
arrow.Field{Name: "value", Type:
arrow.PrimitiveTypes.Int32}),
- "value field must be non-nullable binary type, got
int32"},
+ "value field must be binary type, got int32"},
{arrow.StructOf(arrow.Field{Name: "metadata", Type:
arrow.BinaryTypes.Binary},
arrow.Field{Name: "value", Type:
arrow.BinaryTypes.Binary},
arrow.Field{Name: "extra", Type:
arrow.BinaryTypes.Binary}),
- "has 3 fields, but missing 'typed_value' field"},
+ "has 3 fields, but missing one of 'value' or
'typed_value' field"},
{arrow.StructOf(arrow.Field{Name: "metadata", Type:
arrow.BinaryTypes.Binary, Nullable: true},
arrow.Field{Name: "value", Type:
arrow.BinaryTypes.Binary, Nullable: false}),
"metadata field must be non-nullable binary type"},
- {arrow.StructOf(arrow.Field{Name: "metadata", Type:
arrow.BinaryTypes.Binary, Nullable: false},
- arrow.Field{Name: "value", Type:
arrow.BinaryTypes.Binary, Nullable: true}),
- "value field must be non-nullable binary type"},
{arrow.FixedWidthTypes.Boolean, "bad storage type bool for
variant type"},
{arrow.StructOf(
arrow.Field{Name: "metadata", Type:
arrow.BinaryTypes.Binary, Nullable: false},
@@ -86,16 +83,6 @@ func TestVariantExtensionType(t *testing.T) {
arrow.Field{Name: "metadata", Type:
arrow.BinaryTypes.String, Nullable: false},
arrow.Field{Name: "value", Type:
arrow.BinaryTypes.Binary, Nullable: false}),
"metadata field must be non-nullable binary type, got
utf8"},
- {arrow.StructOf(
- arrow.Field{Name: "metadata", Type:
arrow.BinaryTypes.Binary, Nullable: false},
- arrow.Field{Name: "value", Type:
arrow.BinaryTypes.Binary, Nullable: false},
- arrow.Field{Name: "typed_value", Type:
arrow.BinaryTypes.String, Nullable: true}),
- "value field must be nullable if typed_value is
present"},
- {arrow.StructOf(
- arrow.Field{Name: "metadata", Type:
arrow.BinaryTypes.Binary, Nullable: false},
- arrow.Field{Name: "value", Type:
arrow.BinaryTypes.Binary, Nullable: true},
- arrow.Field{Name: "typed_value", Type:
arrow.BinaryTypes.String, Nullable: false}),
- "typed_value field must be nullable"},
}
for _, tt := range tests {
@@ -126,11 +113,6 @@ func TestVariantExtensionBadNestedTypes(t *testing.T) {
), Nullable: false})},
{"empty struct elem", arrow.StructOf(
arrow.Field{Name: "foobar", Type: arrow.StructOf(),
Nullable: false})},
- {"nullable value struct elem",
- arrow.StructOf(
- arrow.Field{Name: "foobar", Type:
arrow.StructOf(
- arrow.Field{Name: "value", Type:
arrow.BinaryTypes.Binary, Nullable: true},
- ), Nullable: false})},
{"non-nullable two elem struct", arrow.StructOf(
arrow.Field{Name: "foobar", Type: arrow.StructOf(
arrow.Field{Name: "value", Type:
arrow.BinaryTypes.Binary, Nullable: true},
diff --git a/parquet-testing b/parquet-testing
index 2dc8bf14..a3d96a65 160000
--- a/parquet-testing
+++ b/parquet-testing
@@ -1 +1 @@
-Subproject commit 2dc8bf140ed6e28652fc347211c7d661714c7f95
+Subproject commit a3d96a65e11e2bbca7d22a894e8313ede90a33a3
diff --git a/parquet/pqarrow/encode_arrow.go b/parquet/pqarrow/encode_arrow.go
index cdaba241..5724e9f8 100644
--- a/parquet/pqarrow/encode_arrow.go
+++ b/parquet/pqarrow/encode_arrow.go
@@ -333,7 +333,7 @@ func writeDenseArrow(ctx *arrowWriteContext, cw
file.ColumnChunkWriter, leafArr
case arrow.DECIMAL128:
for idx, val := range
leafArr.(*array.Decimal128).Values() {
debug.Assert(val.HighBits() == 0 ||
val.HighBits() == -1, "casting Decimal128 greater than the value range; high
bits must be 0 or -1")
- debug.Assert(val.LowBits() <=
math.MaxUint32, "casting Decimal128 to int32 when value > MaxUint32")
+ debug.Assert(int64(val.LowBits()) <=
math.MaxUint32, "casting Decimal128 to int32 when value > MaxUint32")
data[idx] = int32(val.LowBits())
}
case arrow.DECIMAL256:
diff --git a/parquet/pqarrow/schema.go b/parquet/pqarrow/schema.go
index 2c0e70b5..7c56e333 100644
--- a/parquet/pqarrow/schema.go
+++ b/parquet/pqarrow/schema.go
@@ -242,7 +242,7 @@ func repFromNullable(isnullable bool) parquet.Repetition {
}
func variantToNode(t *extensions.VariantType, field arrow.Field, props
*parquet.WriterProperties, arrProps ArrowWriterProperties) (schema.Node, error)
{
- fields := make(schema.FieldList, 2, 3)
+ fields := make(schema.FieldList, 1, 3)
var err error
fields[0], err = fieldToNode("metadata", t.Metadata(), props, arrProps)
@@ -250,9 +250,12 @@ func variantToNode(t *extensions.VariantType, field
arrow.Field, props *parquet.
return nil, err
}
- fields[1], err = fieldToNode("value", t.Value(), props, arrProps)
- if err != nil {
- return nil, err
+ if value := t.Value(); value.Type != nil {
+ valueField, err := fieldToNode("value", value, props, arrProps)
+ if err != nil {
+ return nil, err
+ }
+ fields = append(fields, valueField)
}
if typed := t.TypedValue(); typed.Type != nil {
@@ -594,8 +597,9 @@ func getParquetType(typ arrow.DataType, props
*parquet.WriterProperties, arrprop
precision := int(dectype.GetPrecision())
scale := int(dectype.GetScale())
+ logicalType := schema.NewDecimalLogicalType(int32(precision),
int32(scale))
if !props.StoreDecimalAsInteger() || precision > 18 {
- return parquet.Types.FixedLenByteArray,
schema.NewDecimalLogicalType(int32(precision), int32(scale)),
int(DecimalSize(int32(precision))), nil
+ return parquet.Types.FixedLenByteArray, logicalType,
int(DecimalSize(int32(precision))), nil
}
pqType := parquet.Types.Int32
@@ -603,7 +607,7 @@ func getParquetType(typ arrow.DataType, props
*parquet.WriterProperties, arrprop
pqType = parquet.Types.Int64
}
- return pqType, schema.NoLogicalType{}, -1, nil
+ return pqType, logicalType, -1, nil
case arrow.DATE32:
return parquet.Types.Int32, schema.DateLogicalType{}, -1, nil
case arrow.DATE64:
@@ -612,14 +616,14 @@ func getParquetType(typ arrow.DataType, props
*parquet.WriterProperties, arrprop
pqType, logicalType, err :=
getTimestampMeta(typ.(*arrow.TimestampType), props, arrprops)
return pqType, logicalType, -1, err
case arrow.TIME32:
- return parquet.Types.Int32, schema.NewTimeLogicalType(true,
schema.TimeUnitMillis), -1, nil
+ return parquet.Types.Int32, schema.NewTimeLogicalType(false,
schema.TimeUnitMillis), -1, nil
case arrow.TIME64:
pqTimeUnit := schema.TimeUnitMicros
if typ.(*arrow.Time64Type).Unit == arrow.Nanosecond {
pqTimeUnit = schema.TimeUnitNanos
}
- return parquet.Types.Int64, schema.NewTimeLogicalType(true,
pqTimeUnit), -1, nil
+ return parquet.Types.Int64, schema.NewTimeLogicalType(false,
pqTimeUnit), -1, nil
case arrow.FLOAT16:
return parquet.Types.FixedLenByteArray,
schema.Float16LogicalType{}, arrow.Float16SizeBytes, nil
case arrow.EXTENSION:
diff --git a/parquet/pqarrow/schema_test.go b/parquet/pqarrow/schema_test.go
index 6f3da880..6f5d14c7 100644
--- a/parquet/pqarrow/schema_test.go
+++ b/parquet/pqarrow/schema_test.go
@@ -184,11 +184,11 @@ func TestConvertArrowFlatPrimitives(t *testing.T) {
arrowFields = append(arrowFields, arrow.Field{Name: "date64", Type:
arrow.FixedWidthTypes.Date64, Nullable: false})
parquetFields = append(parquetFields,
schema.Must(schema.NewPrimitiveNodeLogical("time32",
parquet.Repetitions.Required,
- schema.NewTimeLogicalType(true, schema.TimeUnitMillis),
parquet.Types.Int32, 0, -1)))
+ schema.NewTimeLogicalType(false, schema.TimeUnitMillis),
parquet.Types.Int32, 0, -1)))
arrowFields = append(arrowFields, arrow.Field{Name: "time32", Type:
arrow.FixedWidthTypes.Time32ms, Nullable: false})
parquetFields = append(parquetFields,
schema.Must(schema.NewPrimitiveNodeLogical("time64",
parquet.Repetitions.Required,
- schema.NewTimeLogicalType(true, schema.TimeUnitMicros),
parquet.Types.Int64, 0, -1)))
+ schema.NewTimeLogicalType(false, schema.TimeUnitMicros),
parquet.Types.Int64, 0, -1)))
arrowFields = append(arrowFields, arrow.Field{Name: "time64", Type:
arrow.FixedWidthTypes.Time64us, Nullable: false})
parquetFields = append(parquetFields,
schema.NewInt96Node("timestamp96", parquet.Repetitions.Required, -1))
diff --git a/parquet/pqarrow/variant_test.go b/parquet/pqarrow/variant_test.go
new file mode 100644
index 00000000..81fa246b
--- /dev/null
+++ b/parquet/pqarrow/variant_test.go
@@ -0,0 +1,326 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pqarrow_test
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "iter"
+ "os"
+ "path/filepath"
+ "slices"
+ "strings"
+ "testing"
+ "unsafe"
+
+ "github.com/apache/arrow-go/v18/arrow"
+ "github.com/apache/arrow-go/v18/arrow/endian"
+ "github.com/apache/arrow-go/v18/arrow/extensions"
+ "github.com/apache/arrow-go/v18/arrow/memory"
+ "github.com/apache/arrow-go/v18/internal/json"
+ "github.com/apache/arrow-go/v18/parquet"
+ "github.com/apache/arrow-go/v18/parquet/pqarrow"
+ "github.com/apache/arrow-go/v18/parquet/variant"
+ "github.com/stretchr/testify/suite"
+)
+
+type ShreddedVariantTestSuite struct {
+ suite.Suite
+
+ generate bool
+
+ dirPrefix string
+ outDir string
+ cases []Case
+
+ errorCases []Case
+ singleVariant []Case
+ multiVariant []Case
+}
+
+func (s *ShreddedVariantTestSuite) SetupSuite() {
+ dir := os.Getenv("PARQUET_TEST_DATA")
+ if dir == "" {
+ s.T().Skip("PARQUET_TEST_DATA environment variable not set")
+ }
+
+ s.dirPrefix = filepath.Join(dir, "..", "shredded_variant")
+ s.outDir = filepath.Join(dir, "..", "go_variant")
+ if s.generate {
+ s.Require().NoError(os.MkdirAll(s.outDir, 0o755), "Failed to
create output directory: %s", s.outDir)
+ }
+
+ cases, err := os.Open(filepath.Join(s.dirPrefix, "cases.json"))
+ s.Require().NoError(err, "Failed to open cases.json")
+ defer cases.Close()
+
+ s.Require().NoError(json.NewDecoder(cases).Decode(&s.cases))
+
+ s.errorCases = slices.DeleteFunc(slices.Clone(s.cases), func(c Case)
bool {
+ return c.ErrorMessage == ""
+ })
+
+ s.singleVariant = slices.DeleteFunc(slices.Clone(s.cases), func(c Case)
bool {
+ return c.ErrorMessage != "" || c.VariantFile == "" ||
len(c.VariantFiles) > 0
+ })
+
+ s.multiVariant = slices.DeleteFunc(slices.Clone(s.cases), func(c Case)
bool {
+ return c.ErrorMessage != "" || c.VariantFile != "" ||
len(c.VariantFiles) == 0
+ })
+
+ if s.generate {
+ cases.Seek(0, io.SeekStart)
+ outCases, err := os.Create(filepath.Join(s.outDir,
"cases.json"))
+ s.Require().NoError(err, "Failed to create cases.json")
+ defer outCases.Close()
+
+ io.Copy(outCases, cases)
+ outCases.Sync()
+ }
+}
+
+type Case struct {
+ Number int `json:"case_number"`
+ Title string `json:"test"`
+ Notes string `json:"notes,omitempty"`
+ ParquetFile string `json:"parquet_file"`
+ VariantFile string `json:"variant_file,omitempty"`
+ VariantFiles []*string `json:"variant_files,omitempty"`
+ VariantData string `json:"variant,omitempty"`
+ Variants string `json:"variants,omitempty"`
+ ErrorMessage string `json:"error_message,omitempty"`
+}
+
+func readUnsigned(b []byte) (result uint32) {
+ v := (*[4]byte)(unsafe.Pointer(&result))
+ copy(v[:], b)
+ return endian.FromLE(result)
+}
+
+func (s *ShreddedVariantTestSuite) readVariant(filename string) variant.Value {
+ data, err := os.ReadFile(filename)
+ s.Require().NoError(err, "Failed to read variant file: %s", filename)
+
+ hdr := data[0]
+ offsetSize := int(1 + ((hdr & 0b11000000) >> 6))
+ dictSize := int(readUnsigned(data[1 : 1+offsetSize]))
+ offsetListOffset := 1 + offsetSize
+ dataOffset := offsetListOffset + ((1 + dictSize) * offsetSize)
+
+ idx := offsetListOffset + (offsetSize * dictSize)
+ endOffset := dataOffset + int(readUnsigned(data[idx:idx+offsetSize]))
+ val, err := variant.New(data[:endOffset], data[endOffset:])
+ s.Require().NoError(err, "Failed to create variant from data: %s",
filename)
+ return val
+}
+
+func (s *ShreddedVariantTestSuite) readParquet(filename string) arrow.Table {
+ file, err := os.Open(filepath.Join(s.dirPrefix, filename))
+ s.Require().NoError(err, "Failed to open Parquet file: %s", filename)
+ defer file.Close()
+
+ tbl, err := pqarrow.ReadTable(context.Background(), file, nil,
pqarrow.ArrowReadProperties{}, memory.DefaultAllocator)
+ s.Require().NoError(err, "Failed to read Parquet file: %s", filename)
+ return tbl
+}
+
+func (s *ShreddedVariantTestSuite) writeVariantFile(filename string, val
variant.Value) {
+ out, err := os.Create(filepath.Join(s.outDir, filename))
+ s.Require().NoError(err)
+ defer out.Close()
+
+ _, err = out.Write(val.Metadata().Bytes())
+ s.Require().NoError(err)
+ _, err = out.Write(val.Bytes())
+ s.Require().NoError(err)
+}
+
+func (s *ShreddedVariantTestSuite) writeParquetFile(filename string, tbl
arrow.Table) {
+ out, err := os.Create(filepath.Join(s.outDir, filename))
+ s.Require().NoError(err)
+ defer out.Close()
+
+ s.Require().NoError(pqarrow.WriteTable(tbl, out, max(1, tbl.NumRows()),
parquet.NewWriterProperties(
+ parquet.WithDictionaryDefault(false), parquet.WithStats(false),
+ parquet.WithStoreDecimalAsInteger(true),
+ ), pqarrow.DefaultWriterProps()))
+}
+
+func zip[T, U any](a iter.Seq[T], b iter.Seq[U]) iter.Seq2[T, U] {
+ return func(yield func(T, U) bool) {
+ nexta, stopa := iter.Pull(a)
+ nextb, stopb := iter.Pull(b)
+ defer stopa()
+ defer stopb()
+
+ for {
+ a, ok := nexta()
+ if !ok {
+ return
+ }
+ b, ok := nextb()
+ if !ok {
+ return
+ }
+ if !yield(a, b) {
+ return
+ }
+ }
+ }
+}
+
+func (s *ShreddedVariantTestSuite) assertVariantEqual(expected, actual
variant.Value) {
+ switch expected.BasicType() {
+ case variant.BasicObject:
+ exp := expected.Value().(variant.ObjectValue)
+ act := actual.Value().(variant.ObjectValue)
+
+ s.Equal(exp.NumElements(), act.NumElements(), "Expected %d
elements in object, got %d", exp.NumElements(), act.NumElements())
+ for i := range exp.NumElements() {
+ expectedField, err := exp.FieldAt(i)
+ s.Require().NoError(err, "Failed to get expected field
at index %d", i)
+ actualField, err := act.FieldAt(i)
+ s.Require().NoError(err, "Failed to get actual field at
index %d", i)
+
+ s.Equal(expectedField.Key, actualField.Key, "Expected
field key %s, got %s", expectedField.Key, actualField.Key)
+ s.assertVariantEqual(expectedField.Value,
actualField.Value)
+ }
+ case variant.BasicArray:
+ exp := expected.Value().(variant.ArrayValue)
+ act := actual.Value().(variant.ArrayValue)
+
+ s.Equal(exp.Len(), act.Len(), "Expected array length %d, got
%d", exp.Len(), act.Len())
+ for e, a := range zip(exp.Values(), act.Values()) {
+ s.assertVariantEqual(e, a)
+ }
+ default:
+ switch expected.Type() {
+ case variant.Decimal4, variant.Decimal8, variant.Decimal16:
+ e, err := json.Marshal(expected.Value())
+ s.Require().NoError(err, "Failed to marshal expected
value")
+ a, err := json.Marshal(actual.Value())
+ s.Require().NoError(err, "Failed to marshal actual
value")
+ s.JSONEq(string(e), string(a), "Expected variant value
%s, got %s", e, a)
+ default:
+ s.EqualValues(expected.Value(), actual.Value(),
"Expected variant value %v, got %v", expected.Value(), actual.Value())
+ }
+ }
+}
+
+func (s *ShreddedVariantTestSuite) TestSingleVariantCases() {
+ for _, c := range s.singleVariant {
+ s.Run(c.Title, func() {
+ s.Run(fmt.Sprint(c.Number), func() {
+ if strings.Contains(c.ParquetFile, "-INVALID") {
+ s.T().Skip(c.Notes)
+ }
+
+ expected := s.readVariant(filepath.Join(s.dirPrefix, c.VariantFile))
+ if s.generate {
+ s.writeVariantFile(c.VariantFile, expected)
+ }
+
+ tbl := s.readParquet(c.ParquetFile)
+ defer tbl.Release()
+
+ if s.generate {
+ s.writeParquetFile(c.ParquetFile, tbl)
+ }
+
+ col := tbl.Column(1).Data().Chunk(0)
+ s.Require().IsType(&extensions.VariantArray{}, col)
+
+ variantArray := col.(*extensions.VariantArray)
+ s.Require().Equal(1, variantArray.Len(), "Expected single variant value")
+
+ val, err := variantArray.Value(0)
+ s.Require().NoError(err, "Failed to get variant value from array")
+ s.assertVariantEqual(expected, val)
+ })
+ })
+ }
+}
+
+func (s *ShreddedVariantTestSuite) TestMultiVariantCases() {
+ for _, c := range s.multiVariant {
+ s.Run(c.Title, func() {
+ s.Run(fmt.Sprint(c.Number), func() {
+ tbl := s.readParquet(c.ParquetFile)
+ defer tbl.Release()
+
+ if s.generate {
+ s.writeParquetFile(c.ParquetFile, tbl)
+ }
+
+ s.Require().EqualValues(len(c.VariantFiles), tbl.NumRows(), "Expected number of rows to match number of variant files")
+ col := tbl.Column(1).Data().Chunk(0)
+ s.Require().IsType(&extensions.VariantArray{}, col)
+
+ variantArray := col.(*extensions.VariantArray)
+ for i, variantFile := range c.VariantFiles {
+ if variantFile == nil {
+ s.True(variantArray.IsNull(i), "Expected null value at index %d", i)
+ continue
+ }
+
+ expected := s.readVariant(filepath.Join(s.dirPrefix, *variantFile))
+ if s.generate {
+ s.writeVariantFile(*variantFile, expected)
+ }
+
+ actual, err := variantArray.Value(i)
+ s.Require().NoError(err, "Failed to get variant value at index %d", i)
+ s.assertVariantEqual(expected, actual)
+ }
+ })
+ })
+ }
+}
+
+func (s *ShreddedVariantTestSuite) TestErrorCases() {
+ for _, c := range s.errorCases {
+ s.Run(c.Title, func() {
+ s.Run(fmt.Sprint(c.Number), func() {
+ switch c.Number {
+ case 127:
+ s.T().Skip("Skipping case 127: test says uint32 should error, we just upcast to int64")
+ case 137:
+ s.T().Skip("Skipping case 137: test says flba(4) should error, we just treat it as a binary variant")
+ }
+
+ tbl := s.readParquet(c.ParquetFile)
+ defer tbl.Release()
+
+ if s.generate {
+ s.writeParquetFile(c.ParquetFile, tbl)
+ }
+
+ col := tbl.Column(1).Data().Chunk(0)
+ s.Require().IsType(&extensions.VariantArray{}, col)
+
+ variantArray := col.(*extensions.VariantArray)
+ _, err := variantArray.Value(0)
+ s.Error(err, "Expected error for case %d: %s", c.Number, c.ErrorMessage)
+ })
+ })
+ }
+}
+
+func TestShreddedVariantExamples(t *testing.T) {
+ suite.Run(t, &ShreddedVariantTestSuite{generate: false})
+}
diff --git a/parquet/schema/logical_types.go b/parquet/schema/logical_types.go
index 0c0ce559..e7f1c29f 100644
--- a/parquet/schema/logical_types.go
+++ b/parquet/schema/logical_types.go
@@ -24,6 +24,7 @@ import (
"github.com/apache/arrow-go/v18/parquet"
"github.com/apache/arrow-go/v18/parquet/internal/debug"
format "github.com/apache/arrow-go/v18/parquet/internal/gen-go/parquet"
+ "github.com/apache/thrift/lib/go/thrift"
)
// DecimalMetadata is a struct for managing scale and precision information between
@@ -1139,7 +1140,7 @@ func (VariantLogicalType) IsCompatible(ct ConvertedType, _ DecimalMetadata) bool
func (VariantLogicalType) IsApplicable(parquet.Type, int32) bool { return false }
func (VariantLogicalType) toThrift() *format.LogicalType {
- return &format.LogicalType{VARIANT: format.NewVariantType()}
+ return &format.LogicalType{VARIANT: &format.VariantType{SpecificationVersion: thrift.Int8Ptr(1)}}
}
func (VariantLogicalType) Equals(rhs LogicalType) bool {
diff --git a/parquet/variant/builder.go b/parquet/variant/builder.go
index 194814c6..68fc178d 100644
--- a/parquet/variant/builder.go
+++ b/parquet/variant/builder.go
@@ -887,7 +887,7 @@ func (b *Builder) Build() (Value, error) {
type variantPrimitiveType interface {
constraints.Integer | constraints.Float | string | []byte |
arrow.Date32 | arrow.Time64 | arrow.Timestamp | bool |
- uuid.UUID | DecimalValue[decimal.Decimal32] |
+ uuid.UUID | DecimalValue[decimal.Decimal32] | time.Time |
DecimalValue[decimal.Decimal64] |
DecimalValue[decimal.Decimal128]
}
@@ -895,17 +895,25 @@ type variantPrimitiveType interface {
// variant value. At the moment this is just delegating to the [Builder.Append] method,
// but in the future it will be optimized to avoid the extra overhead and reduce allocations.
func Encode[T variantPrimitiveType](v T, opt ...AppendOpt) ([]byte, error) {
+ out, err := Of(v, opt...)
+ if err != nil {
+ return nil, fmt.Errorf("failed to encode variant value: %w", err)
+ }
+ return out.value, nil
+}
+
+func Of[T variantPrimitiveType](v T, opt ...AppendOpt) (Value, error) {
var b Builder
if err := b.Append(v, opt...); err != nil {
- return nil, fmt.Errorf("failed to append value: %w", err)
+ return Value{}, fmt.Errorf("failed to append value: %w", err)
}
val, err := b.Build()
if err != nil {
- return nil, fmt.Errorf("failed to build variant value: %w", err)
+ return Value{}, fmt.Errorf("failed to build variant value: %w", err)
}
- return val.value, nil
+ return val, nil
}
func ParseJSON(data string, allowDuplicateKeys bool) (Value, error) {
diff --git a/parquet/variant/builder_test.go b/parquet/variant/builder_test.go
index 09fa80eb..982fa4e9 100644
--- a/parquet/variant/builder_test.go
+++ b/parquet/variant/builder_test.go
@@ -57,9 +57,7 @@ func TestBuildPrimitive(t *testing.T) {
{"primitive_int8", func(b *variant.Builder) error { return b.AppendInt(42) }},
{"primitive_int16", func(b *variant.Builder) error { return b.AppendInt(1234) }},
{"primitive_int32", func(b *variant.Builder) error { return b.AppendInt(123456) }},
- // FIXME: https://github.com/apache/parquet-testing/issues/82
- // primitive_int64 is an int32 value, but the metadata is int64
- {"primitive_int64", func(b *variant.Builder) error { return b.AppendInt(12345678) }},
+ {"primitive_int64", func(b *variant.Builder) error { return b.AppendInt(1234567890123456789) }},
{"primitive_float", func(b *variant.Builder) error { return b.AppendFloat32(1234568000) }},
{"primitive_double", func(b *variant.Builder) error { return b.AppendFloat64(1234567890.1234) }},
{"primitive_string", func(b *variant.Builder) error {
diff --git a/parquet/variant/variant.go b/parquet/variant/variant.go
index 800b7eb2..254bc3c3 100644
--- a/parquet/variant/variant.go
+++ b/parquet/variant/variant.go
@@ -650,7 +650,10 @@ func (v Value) Value() any {
}
case BasicShortString:
sz := int(v.value[0] >> 2)
- return unsafe.String(&v.value[1], sz)
+ if sz > 0 {
+ return unsafe.String(&v.value[1], sz)
+ }
+ return ""
case BasicObject:
valueHdr := (v.value[0] >> basicTypeBits)
fieldOffsetSz := (valueHdr & 0b11) + 1
diff --git a/parquet/variant/variant_test.go b/parquet/variant/variant_test.go
index 2ef4da38..c623f646 100644
--- a/parquet/variant/variant_test.go
+++ b/parquet/variant/variant_test.go
@@ -152,9 +152,7 @@ func TestPrimitiveVariants(t *testing.T) {
{"primitive_int8", int8(42), variant.Int8, "42"},
{"primitive_int16", int16(1234), variant.Int16, "1234"},
{"primitive_int32", int32(123456), variant.Int32, "123456"},
- // FIXME: https://github.com/apache/parquet-testing/issues/82
- // primitive_int64 is an int32 value, but the metadata is int64
- {"primitive_int64", int32(12345678), variant.Int32, "12345678"},
+ {"primitive_int64", int64(1234567890123456789), variant.Int64, "1234567890123456789"},
{"primitive_float", float32(1234567940.0), variant.Float, "1234568000"},
{"primitive_double", float64(1234567890.1234), variant.Double, "1234567890.1234"},
{"primitive_string",