This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-go.git


The following commit(s) were added to refs/heads/main by this push:
     new fff99b3c fix(parquet/pqarrow): unsupported dictionary types in pqarrow 
(#520)
fff99b3c is described below

commit fff99b3c9aa884dcd2846a33b7335c9768e78f21
Author: Matt Topol <[email protected]>
AuthorDate: Wed Oct 1 12:16:50 2025 +0200

    fix(parquet/pqarrow): unsupported dictionary types in pqarrow (#520)
    
    ### Rationale for this change
    fixes #490
    
    ### What changes are included in this PR?
When using `PutDictionary` to write an arrow dictionary array to Parquet
    directly, we will now normalize the data type to a parquet physical
    storage type to avoid type problems. For example, a dictionary with `int8`
    values will get cast to `int32` values before we process it so that all
    of the remaining encoding code can continue to be based on the physical
    storage type. This alleviates panics when using `pqarrow` to write
    Dictionary arrays.
    
    ### Are these changes tested?
    Yes, the unit tests are updated to include tests for int8/uint8/etc. as
    dictionary value types.
    
    ### Are there any user-facing changes?
    No, just fixing panics.
---
 parquet/internal/encoding/byte_array_encoder.go    |  5 ++
 .../encoding/fixed_len_byte_array_encoder.go       |  5 ++
 parquet/internal/encoding/typed_encoder.go         | 30 +++++++++--
 parquet/internal/encoding/types.go                 |  6 +++
 parquet/pqarrow/encode_arrow_test.go               |  4 ++
 parquet/pqarrow/encode_dict_compute.go             | 20 +++++---
 parquet/pqarrow/encode_dictionary_test.go          | 59 +++++++++++++---------
 7 files changed, 93 insertions(+), 36 deletions(-)

diff --git a/parquet/internal/encoding/byte_array_encoder.go 
b/parquet/internal/encoding/byte_array_encoder.go
index 84323b12..d0439faf 100644
--- a/parquet/internal/encoding/byte_array_encoder.go
+++ b/parquet/internal/encoding/byte_array_encoder.go
@@ -122,6 +122,11 @@ func (enc *DictByteArrayEncoder) PutSpaced(in 
[]parquet.ByteArray, validBits []b
        })
 }
 
+func (enc *DictByteArrayEncoder) NormalizeDict(values arrow.Array) 
(arrow.Array, error) {
+       values.Retain()
+       return values, nil
+}
+
 // PutDictionary allows pre-seeding a dictionary encoder with
 // a dictionary from an Arrow Array.
 //
diff --git a/parquet/internal/encoding/fixed_len_byte_array_encoder.go 
b/parquet/internal/encoding/fixed_len_byte_array_encoder.go
index 56cf242b..0f4345d2 100644
--- a/parquet/internal/encoding/fixed_len_byte_array_encoder.go
+++ b/parquet/internal/encoding/fixed_len_byte_array_encoder.go
@@ -146,6 +146,11 @@ func (enc *DictFixedLenByteArrayEncoder) PutSpaced(in 
[]parquet.FixedLenByteArra
        })
 }
 
+func (enc *DictFixedLenByteArrayEncoder) NormalizeDict(values arrow.Array) 
(arrow.Array, error) {
+       values.Retain()
+       return values, nil
+}
+
 // PutDictionary allows pre-seeding a dictionary encoder with
 // a dictionary from an Arrow Array.
 //
diff --git a/parquet/internal/encoding/typed_encoder.go 
b/parquet/internal/encoding/typed_encoder.go
index bbc468b8..1cba18da 100644
--- a/parquet/internal/encoding/typed_encoder.go
+++ b/parquet/internal/encoding/typed_encoder.go
@@ -17,11 +17,13 @@
 package encoding
 
 import (
+       "context"
        "errors"
        "fmt"
        "unsafe"
 
        "github.com/apache/arrow-go/v18/arrow"
+       "github.com/apache/arrow-go/v18/arrow/compute"
        "github.com/apache/arrow-go/v18/arrow/memory"
        "github.com/apache/arrow-go/v18/internal/bitutils"
        shared_utils "github.com/apache/arrow-go/v18/internal/utils"
@@ -151,21 +153,41 @@ func (enc *typedDictEncoder[T]) PutSpaced(in []T, 
validBits []byte, validBitsOff
        })
 }
 
-type arrvalues[T int32 | int64 | float32 | float64] interface {
+type arrvalues[T arrow.ValueType] interface {
        arrow.TypedArray[T]
        Values() []T
 }
 
+func (enc *typedDictEncoder[T]) NormalizeDict(values arrow.Array) 
(arrow.Array, error) {
+       if _, ok := values.(arrvalues[T]); ok {
+               values.Retain()
+               return values, nil
+       }
+
+       ctx := compute.WithAllocator(context.Background(), enc.mem)
+       return compute.CastToType(ctx, values, arrow.GetDataType[T]())
+}
+
 func (enc *typedDictEncoder[T]) PutDictionary(values arrow.Array) error {
        if err := enc.canPutDictionary(values); err != nil {
                return err
        }
 
        enc.dictEncodedSize += values.Len() * int(unsafe.Sizeof(T(0)))
-       data := values.(arrvalues[T]).Values()
-
        typedMemo := enc.memo.(TypedMemoTable[T])
-       for _, val := range data {
+       data, ok := values.(arrvalues[T])
+       if !ok {
+               var err error
+               ctx := compute.WithAllocator(context.Background(), enc.mem)
+               values, err = compute.CastToType(ctx, values, 
arrow.GetDataType[T]())
+               if err != nil {
+                       return err
+               }
+               defer values.Release()
+               data = values.(arrvalues[T])
+       }
+
+       for _, val := range data.Values() {
                if _, _, err := typedMemo.InsertOrGet(val); err != nil {
                        return err
                }
diff --git a/parquet/internal/encoding/types.go 
b/parquet/internal/encoding/types.go
index 67f627d8..ccd3db2a 100644
--- a/parquet/internal/encoding/types.go
+++ b/parquet/internal/encoding/types.go
@@ -113,6 +113,12 @@ type DictEncoder interface {
        // of [0,dictSize) and is not validated. Returns an error if a 
non-integral
        // array is passed.
        PutIndices(arrow.Array) error
+       // NormalizeDict takes an arrow array and normalizes it to a parquet
+       // native type. e.g. a dictionary of type int8 will be cast to an int32
+       // dictionary for parquet storage.
+       //
+       // The returned array must always be released by the caller.
+       NormalizeDict(arrow.Array) (arrow.Array, error)
 }
 
 var bufferPool = sync.Pool{
diff --git a/parquet/pqarrow/encode_arrow_test.go 
b/parquet/pqarrow/encode_arrow_test.go
index a279f931..e9df2af9 100644
--- a/parquet/pqarrow/encode_arrow_test.go
+++ b/parquet/pqarrow/encode_arrow_test.go
@@ -1352,6 +1352,10 @@ var fullTypeList = []arrow.DataType{
 }
 
 var dictEncodingSupportedTypeList = []arrow.DataType{
+       arrow.PrimitiveTypes.Int8,
+       arrow.PrimitiveTypes.Uint8,
+       arrow.PrimitiveTypes.Int16,
+       arrow.PrimitiveTypes.Uint16,
        arrow.PrimitiveTypes.Int32,
        arrow.PrimitiveTypes.Int64,
        arrow.PrimitiveTypes.Float32,
diff --git a/parquet/pqarrow/encode_dict_compute.go 
b/parquet/pqarrow/encode_dict_compute.go
index 9f76cbfb..ab587503 100644
--- a/parquet/pqarrow/encode_dict_compute.go
+++ b/parquet/pqarrow/encode_dict_compute.go
@@ -90,12 +90,18 @@ func writeDictionaryArrow(ctx *arrowWriteContext, cw 
file.ColumnChunkWriter, lea
                pageStats   = cw.PageStatistics()
        )
 
+       normalized, err := dictEncoder.NormalizeDict(dict)
+       if err != nil {
+               return err
+       }
+       defer normalized.Release()
+
        updateStats := func() error {
                var referencedDict arrow.Array
 
                ctx := compute.WithAllocator(context.Background(), 
ctx.props.mem)
                // if dictionary is the same dictionary we already have, just 
use that
-               if preserved != nil && preserved == dict {
+               if preserved != nil && preserved == normalized {
                        referencedDict = preserved
                } else {
                        referencedIndices, err := compute.UniqueArray(ctx, 
indices)
@@ -104,10 +110,10 @@ func writeDictionaryArrow(ctx *arrowWriteContext, cw 
file.ColumnChunkWriter, lea
                        }
 
                        // on first run, we might be able to re-use the 
existing dict
-                       if referencedIndices.Len() == dict.Len() {
-                               referencedDict = dict
+                       if referencedIndices.Len() == normalized.Len() {
+                               referencedDict = normalized
                        } else {
-                               referencedDict, err = 
compute.TakeArrayOpts(ctx, dict, referencedIndices, 
compute.TakeOptions{BoundsCheck: false})
+                               referencedDict, err = 
compute.TakeArrayOpts(ctx, normalized, referencedIndices, 
compute.TakeOptions{BoundsCheck: false})
                                if err != nil {
                                        return err
                                }
@@ -126,7 +132,7 @@ func writeDictionaryArrow(ctx *arrowWriteContext, cw 
file.ColumnChunkWriter, lea
 
        switch {
        case preserved == nil:
-               if err := dictEncoder.PutDictionary(dict); err != nil {
+               if err := dictEncoder.PutDictionary(normalized); err != nil {
                        return err
                }
 
@@ -134,7 +140,7 @@ func writeDictionaryArrow(ctx *arrowWriteContext, cw 
file.ColumnChunkWriter, lea
                // memo table will be out of sync with the indices in the arrow 
array
                // the easiest solution for this uncommon case is to fallback 
to plain
                // encoding
-               if dictEncoder.NumEntries() != dict.Len() {
+               if dictEncoder.NumEntries() != normalized.Len() {
                        cw.FallbackToPlain()
                        return writeDense()
                }
@@ -145,7 +151,7 @@ func writeDictionaryArrow(ctx *arrowWriteContext, cw 
file.ColumnChunkWriter, lea
                        }
                }
 
-       case !array.Equal(dict, preserved):
+       case !array.Equal(normalized, preserved):
                // dictionary has changed
                cw.FallbackToPlain()
                return writeDense()
diff --git a/parquet/pqarrow/encode_dictionary_test.go 
b/parquet/pqarrow/encode_dictionary_test.go
index e47b7e6b..fecdc649 100644
--- a/parquet/pqarrow/encode_dictionary_test.go
+++ b/parquet/pqarrow/encode_dictionary_test.go
@@ -68,45 +68,54 @@ func (ps *ParquetIOTestSuite) 
TestSingleColumnOptionalDictionaryWrite() {
 }
 
 func (ps *ParquetIOTestSuite) TestSingleColumnRequiredDictionaryWrite() {
+       idxTypes := []arrow.DataType{
+               arrow.PrimitiveTypes.Int8, arrow.PrimitiveTypes.Int16,
+               arrow.PrimitiveTypes.Uint8, arrow.PrimitiveTypes.Uint16,
+               arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Int64,
+               arrow.PrimitiveTypes.Uint32, arrow.PrimitiveTypes.Uint64}
+
        for _, dt := range dictEncodingSupportedTypeList {
                // skip tests for bool as we don't do dictionaries for it
                if dt.ID() == arrow.BOOL {
                        continue
                }
 
-               ps.Run(dt.Name(), func() {
-                       mem := 
memory.NewCheckedAllocator(memory.DefaultAllocator)
-                       defer mem.AssertSize(ps.T(), 0)
+               for _, idxtype := range idxTypes {
 
-                       bldr := array.NewDictionaryBuilder(mem, 
&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int16, ValueType: dt})
-                       defer bldr.Release()
+                       ps.Run(dt.Name()+" key="+idxtype.Name(), func() {
+                               mem := 
memory.NewCheckedAllocator(memory.DefaultAllocator)
+                               defer mem.AssertSize(ps.T(), 0)
 
-                       values := testutils.RandomNonNull(mem, dt, smallSize)
-                       defer values.Release()
-                       ps.Require().NoError(bldr.AppendArray(values))
+                               bldr := array.NewDictionaryBuilder(mem, 
&arrow.DictionaryType{IndexType: idxtype, ValueType: dt})
+                               defer bldr.Release()
 
-                       arr := bldr.NewDictionaryArray()
-                       defer arr.Release()
+                               values := testutils.RandomNonNull(mem, dt, 
smallSize)
+                               defer values.Release()
+                               ps.Require().NoError(bldr.AppendArray(values))
 
-                       sc := ps.makeSimpleSchema(arr.DataType(), 
parquet.Repetitions.Required)
-                       data := ps.writeDictionaryColumn(mem, sc, arr)
+                               arr := bldr.NewDictionaryArray()
+                               defer arr.Release()
 
-                       rdr, err := file.NewParquetReader(bytes.NewReader(data))
-                       ps.NoError(err)
-                       defer rdr.Close()
+                               sc := ps.makeSimpleSchema(arr.DataType(), 
parquet.Repetitions.Required)
+                               data := ps.writeDictionaryColumn(mem, sc, arr)
 
-                       metadata := rdr.MetaData()
-                       ps.Len(metadata.RowGroups, 1)
+                               rdr, err := 
file.NewParquetReader(bytes.NewReader(data))
+                               ps.NoError(err)
+                               defer rdr.Close()
 
-                       rg := metadata.RowGroup(0)
-                       col, err := rg.ColumnChunk(0)
-                       ps.NoError(err)
+                               metadata := rdr.MetaData()
+                               ps.Len(metadata.RowGroups, 1)
 
-                       stats, err := col.Statistics()
-                       ps.NoError(err)
-                       ps.EqualValues(smallSize, stats.NumValues())
-                       ps.EqualValues(0, stats.NullCount())
-               })
+                               rg := metadata.RowGroup(0)
+                               col, err := rg.ColumnChunk(0)
+                               ps.NoError(err)
+
+                               stats, err := col.Statistics()
+                               ps.NoError(err)
+                               ps.EqualValues(smallSize, stats.NumValues())
+                               ps.EqualValues(0, stats.NullCount())
+                       })
+               }
        }
 }
 

Reply via email to