This is an automated email from the ASF dual-hosted git repository. zeroshade pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/arrow-go.git
The following commit(s) were added to refs/heads/main by this push: new 42c448e feat(arrow/compute): implement "is_in" function (#319) 42c448e is described below commit 42c448ed72e1a122897f583f72f2fd223a34ec75 Author: Matt Topol <zotthewiz...@gmail.com> AuthorDate: Mon Mar 24 08:45:33 2025 -0400 feat(arrow/compute): implement "is_in" function (#319) ### Rationale for this change Since we use arrow-go for iceberg-go and utilize the compute libraries for filtering, we need to ensure we add support for the minimal number of functions that iceberg requires. ### What changes are included in this PR? Implementing the `is_in` function for the function registry, registering it by default, and ensuring we also allow using `is_in` from substrait. ### Are these changes tested? Yes, unit tests are included. ### Are there any user-facing changes? There shouldn't be any user-facing changes. --- arrow/array/builder.go | 6 +- arrow/array/float16.go | 12 - arrow/compute/arithmetic_test.go | 2 +- arrow/compute/exec/kernel.go | 10 + arrow/compute/executor.go | 7 +- arrow/compute/expression.go | 1 + arrow/compute/exprs/exec.go | 12 +- arrow/compute/exprs/extension_types.go | 10 +- arrow/compute/exprs/types.go | 52 +- .../compute/internal/kernels/scalar_set_lookup.go | 297 ++++++++++ arrow/compute/internal/kernels/vector_hash.go | 6 +- arrow/compute/internal/kernels/vector_selection.go | 10 +- arrow/compute/registry.go | 1 + arrow/compute/scalar_set_lookup.go | 222 ++++++++ arrow/compute/scalar_set_lookup_test.go | 606 +++++++++++++++++++++ arrow/scalar/append.go | 107 ++-- internal/hashing/xxh3_memo_table.gen.go | 146 +++-- internal/hashing/xxh3_memo_table.gen.go.tmpl | 25 +- internal/hashing/xxh3_memo_table.go | 40 +- parquet/file/file_writer.go | 5 +- parquet/pqarrow/file_writer.go | 4 + parquet/pqarrow/schema.go | 12 +- 22 files changed, 1441 insertions(+), 152 deletions(-) diff --git a/arrow/array/builder.go b/arrow/array/builder.go index a2a40d4..f5b5ee4 100644 --- a/arrow/array/builder.go +++ b/arrow/array/builder.go @@ -176,13 +176,13 @@ func (b *builder) resize(newBits int, init func(int)) { } func (b *builder) reserve(elements int, resize func(int)) { - if b.nullBitmap == nil { - b.nullBitmap = memory.NewResizableBuffer(b.mem) - } if b.length+elements > b.capacity { newCap := bitutil.NextPowerOf2(b.length + elements) resize(newCap) } + if b.nullBitmap == nil { + b.nullBitmap = memory.NewResizableBuffer(b.mem) + } } // unsafeAppendBoolsToBitmap appends the contents of valid to the validity bitmap. diff --git a/arrow/array/float16.go b/arrow/array/float16.go index c98362d..a472cfa 100644 --- a/arrow/array/float16.go +++ b/arrow/array/float16.go @@ -106,18 +106,6 @@ func (a *Float16) MarshalJSON() ([]byte, error) { return json.Marshal(vals) } -func arrayEqualFloat16(left, right *Float16) bool { - for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { - continue - } - if left.Value(i) != right.Value(i) { - return false - } - } - return true -} - var ( _ arrow.Array = (*Float16)(nil) _ arrow.TypedArray[float16.Num] = (*Float16)(nil) diff --git a/arrow/compute/arithmetic_test.go b/arrow/compute/arithmetic_test.go index 6db0212..07fb1fc 100644 --- a/arrow/compute/arithmetic_test.go +++ b/arrow/compute/arithmetic_test.go @@ -204,7 +204,7 @@ type BinaryArithmeticSuite[T arrow.NumericType] struct { scalarEqualOpts []scalar.EqualOption } -func (BinaryArithmeticSuite[T]) DataType() arrow.DataType { +func (*BinaryArithmeticSuite[T]) DataType() arrow.DataType { return arrow.GetDataType[T]() } diff --git a/arrow/compute/exec/kernel.go b/arrow/compute/exec/kernel.go index d7de176..fd3a52d 100644 --- a/arrow/compute/exec/kernel.go +++ b/arrow/compute/exec/kernel.go @@ -68,6 +68,7 @@ type NonAggKernel interface { GetNullHandling() NullHandling GetMemAlloc() MemAlloc CanFillSlices() bool + Cleanup() error } // KernelCtx is a small struct holding the context for a kernel execution @@ -604,6 +605,7 @@ type ScalarKernel struct { CanWriteIntoSlices bool NullHandling NullHandling MemAlloc MemAlloc + CleanupFn func(KernelState) error } // NewScalarKernel constructs a new kernel for scalar execution, constructing @@ -629,6 +631,13 @@ func NewScalarKernelWithSig(sig *KernelSignature, exec ArrayKernelExec, init Ker } } +func (s *ScalarKernel) Cleanup() error { + if s.CleanupFn != nil { + return s.CleanupFn(s.Data) + } + return nil +} + func (s *ScalarKernel) Exec(ctx *KernelCtx, sp *ExecSpan, out *ExecResult) error { return s.ExecFn(ctx, sp, out) } @@ -693,3 +702,4 @@ func (s *VectorKernel) Exec(ctx *KernelCtx, sp *ExecSpan, out *ExecResult) error func (s VectorKernel) GetNullHandling() NullHandling { return s.NullHandling } func (s VectorKernel) GetMemAlloc() MemAlloc { return s.MemAlloc } func (s VectorKernel) CanFillSlices() bool { return s.CanWriteIntoSlices } +func (s VectorKernel) Cleanup() error { return nil } diff --git a/arrow/compute/executor.go b/arrow/compute/executor.go index 54c65ad..bf41036 100644 --- a/arrow/compute/executor.go +++ b/arrow/compute/executor.go @@ -20,6 +20,7 @@ package compute import ( "context" + "errors" "fmt" "math" "runtime" @@ -579,6 +580,10 @@ func (s *scalarExecutor) WrapResults(ctx context.Context, out <-chan Datum, hasC } func (s *scalarExecutor) executeSpans(data chan<- Datum) (err error) { + defer func() { + err = errors.Join(err, s.kernel.Cleanup()) + }() + var ( input exec.ExecSpan output exec.ExecResult @@ -645,7 +650,7 @@ func (s *scalarExecutor) executeSingleSpan(input *exec.ExecSpan, out *exec.ExecR return s.kernel.Exec(s.ctx, input, out) } -func (s *scalarExecutor) setupPrealloc(totalLen int64, args []Datum) error { +func (s *scalarExecutor) setupPrealloc(_ int64, args []Datum) error { s.numOutBuf = len(s.outType.Layout().Buffers) outTypeID := s.outType.ID() // default to no validity pre-allocation for the following cases: diff --git a/arrow/compute/expression.go b/arrow/compute/expression.go index 88e1dde..4e60d38 100644 --- a/arrow/compute/expression.go +++ b/arrow/compute/expression.go @@ -490,6 +490,7 @@ func Cast(ex Expression, dt arrow.DataType) Expression { return NewCall("cast", []Expression{ex}, opts) } +// Deprecated: Use SetOptions instead type SetLookupOptions struct { ValueSet Datum `compute:"value_set"` SkipNulls bool `compute:"skip_nulls"` diff --git a/arrow/compute/exprs/exec.go b/arrow/compute/exprs/exec.go index 2e64381..0d0a139 100644 --- a/arrow/compute/exprs/exec.go +++ b/arrow/compute/exprs/exec.go @@ -524,7 +524,6 @@ func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.E err error allScalar = true args = make([]compute.Datum, e.NArgs()) - argTypes = make([]arrow.DataType, e.NArgs()) ) for i := 0; i < e.NArgs(); i++ { switch v := e.Arg(i).(type) { @@ -543,20 +542,23 @@ func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.E default: return nil, arrow.ErrNotImplemented } - - argTypes[i] = args[i].(compute.ArrayLikeDatum).Type() } _, conv, ok := ext.DecodeFunction(e.FuncRef()) if !ok { - return nil, arrow.ErrNotImplemented + return nil, fmt.Errorf("%w: %s", arrow.ErrNotImplemented, e.Name()) } - fname, opts, err := conv(e) + fname, args, opts, err := conv(e, args) if err != nil { return nil, err } + argTypes := make([]arrow.DataType, len(args)) + for i, arg := range args { + argTypes[i] = arg.(compute.ArrayLikeDatum).Type() + } + ectx := compute.GetExecCtx(ctx) fn, ok := ectx.Registry.GetFunction(fname) if !ok { diff --git a/arrow/compute/exprs/extension_types.go b/arrow/compute/exprs/extension_types.go index db780cb..f44da19 100644 --- a/arrow/compute/exprs/extension_types.go +++ b/arrow/compute/exprs/extension_types.go @@ -26,6 +26,7 @@ import ( "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/extensions" ) type simpleExtensionTypeFactory[P comparable] struct { @@ -95,13 +96,6 @@ type simpleExtensionArrayFactory[P comparable] struct { array.ExtensionArrayBase } -type uuidExtParams struct{} - -var uuidType = simpleExtensionTypeFactory[uuidExtParams]{ - name: "arrow.uuid", getStorage: func(uuidExtParams) arrow.DataType { - return &arrow.FixedSizeBinaryType{ByteWidth: 16} - }} - type fixedCharExtensionParams struct { Length int32 `json:"length"` } @@ -138,7 +132,7 @@ var intervalDayType = simpleExtensionTypeFactory[intervalDayExtensionParams]{ }, } -func uuid() arrow.DataType { return uuidType.CreateType(uuidExtParams{}) } +func uuid() arrow.DataType { return extensions.NewUUIDType() } func fixedChar(length int32) arrow.DataType { return fixedCharType.CreateType(fixedCharExtensionParams{Length: length}) } diff --git a/arrow/compute/exprs/types.go b/arrow/compute/exprs/types.go index f48a6c5..0c468f3 100644 --- a/arrow/compute/exprs/types.go +++ b/arrow/compute/exprs/types.go @@ -26,6 +26,7 @@ import ( "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/compute" + "github.com/apache/arrow-go/v18/arrow/scalar" "github.com/substrait-io/substrait-go/v3/expr" "github.com/substrait-io/substrait-go/v3/extensions" "github.com/substrait-io/substrait-go/v3/types" @@ -41,7 +42,8 @@ const ( SubstraitComparisonFuncsURI = SubstraitDefaultURIPrefix + "functions_comparison.yaml" SubstraitBooleanFuncsURI = SubstraitDefaultURIPrefix + "functions_boolean.yaml" - TimestampTzTimezone = "UTC" + SubstraitIcebergSetFuncURI = "https://github.com/apache/iceberg-go/blob/main/table/substrait/functions_set.yaml" + TimestampTzTimezone = "UTC" ) var hashSeed maphash.Seed @@ -127,6 +129,15 @@ func init() { panic(err) } } + + for _, fn := range []string{"is_in"} { + err := DefaultExtensionIDRegistry.AddSubstraitScalarToArrow( + extensions.ID{URI: SubstraitIcebergSetFuncURI, Name: fn}, + setLookupFuncSubstraitToArrowFunc) + if err != nil { + panic(err) + } + } } type overflowBehavior string @@ -178,7 +189,7 @@ func parseOption[typ ~string](sf *expr.ScalarFunction, optionName string, parser return def, arrow.ErrNotImplemented } -type substraitToArrow = func(*expr.ScalarFunction) (fname string, opts compute.FunctionOptions, err error) +type substraitToArrow = func(*expr.ScalarFunction, []compute.Datum) (fname string, args []compute.Datum, opts compute.FunctionOptions, err error) type arrowToSubstrait = func(fname string) (extensions.ID, []*types.FunctionOption, error) var substraitToArrowFuncMap = map[string]string{ @@ -199,7 +210,32 @@ var arrowToSubstraitFuncMap = map[string]string{ "or_kleene": "or", } -func simpleMapSubstraitToArrowFunc(sf *expr.ScalarFunction) (fname string, opts compute.FunctionOptions, err error) { +func setLookupFuncSubstraitToArrowFunc(sf *expr.ScalarFunction, input []compute.Datum) (fname string, args []compute.Datum, opts compute.FunctionOptions, err error) { + fname, _, _ = strings.Cut(sf.Name(), ":") + f, ok := substraitToArrowFuncMap[fname] + if ok { + fname = f + } + + setopts := &compute.SetOptions{ + NullBehavior: compute.NullMatchingMatch, + } + switch input[1].Kind() { + case compute.KindArray, compute.KindChunked: + setopts.ValueSet = input[1] + case compute.KindScalar: + // should be a list scalar + setopts.ValueSet = compute.NewDatumWithoutOwning( + input[1].(*compute.ScalarDatum).Value.(*scalar.List).Value) + } + + args, opts = input[0:1], setopts + return +} + +func simpleMapSubstraitToArrowFunc(sf *expr.ScalarFunction, input []compute.Datum) (fname string, args []compute.Datum, opts compute.FunctionOptions, err error) { + args = input + fname, _, _ = strings.Cut(sf.Name(), ":") f, ok := substraitToArrowFuncMap[fname] if ok { @@ -219,19 +255,19 @@ func simpleMapArrowToSubstraitFunc(uri string) arrowToSubstrait { } func decodeOptionlessOverflowableArithmetic(n string) substraitToArrow { - return func(sf *expr.ScalarFunction) (fname string, opts compute.FunctionOptions, err error) { + return func(sf *expr.ScalarFunction, input []compute.Datum) (fname string, args []compute.Datum, opts compute.FunctionOptions, err error) { overflow, err := parseOption(sf, "overflow", &overflowParser, []overflowBehavior{overflowSILENT, overflowERROR}, overflowSILENT) if err != nil { - return n, nil, err + return n, input, nil, err } switch overflow { case overflowSILENT: - return n + "_unchecked", nil, nil + return n + "_unchecked", input, nil, nil case overflowERROR: - return n, nil, nil + return n, input, nil, nil default: - return n, nil, arrow.ErrNotImplemented + return n, input, nil, arrow.ErrNotImplemented } } } diff --git a/arrow/compute/internal/kernels/scalar_set_lookup.go b/arrow/compute/internal/kernels/scalar_set_lookup.go new file mode 100644 index 0000000..c356267 --- /dev/null +++ b/arrow/compute/internal/kernels/scalar_set_lookup.go @@ -0,0 +1,297 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernels + +import ( + "fmt" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/bitutil" + "github.com/apache/arrow-go/v18/arrow/compute/exec" + "github.com/apache/arrow-go/v18/arrow/internal/debug" + "github.com/apache/arrow-go/v18/arrow/memory" + "github.com/apache/arrow-go/v18/internal/bitutils" + "github.com/apache/arrow-go/v18/internal/hashing" +) + +type NullMatchingBehavior int8 + +const ( + NullMatchingMatch NullMatchingBehavior = iota + NullMatchingSkip + NullMatchingEmitNull + NullMatchingInconclusive +) + +func visitBinary[OffsetT int32 | int64](data *exec.ArraySpan, valid func([]byte) error, null func() error) error { + if data.Len == 0 { + return nil + } + + rawBytes := data.Buffers[2].Buf + offsets := exec.GetSpanOffsets[OffsetT](data, 1) + return bitutils.VisitBitBlocksShort(data.Buffers[0].Buf, data.Offset, data.Len, + func(pos int64) error { + return valid(rawBytes[offsets[pos]:offsets[pos+1]]) + }, null) +} + +func visitNumeric[T arrow.FixedWidthType](data *exec.ArraySpan, valid func(T) error, null func() error) error { + if data.Len == 0 { + return nil + } + + values := exec.GetSpanValues[T](data, 1) + return bitutils.VisitBitBlocksShort(data.Buffers[0].Buf, data.Offset, data.Len, + func(pos int64) error { + return valid(values[pos]) + }, null) +} + +func visitFSB(data *exec.ArraySpan, valid func([]byte) error, null func() error) error { + if data.Len == 0 { + return nil + } + + sz := int64(data.Type.(arrow.FixedWidthDataType).Bytes()) + rawBytes := data.Buffers[1].Buf + + return bitutils.VisitBitBlocksShort(data.Buffers[0].Buf, data.Offset, data.Len, + func(pos int64) error { + return valid(rawBytes[pos*sz : (pos+1)*sz]) + }, null) +} + +type SetLookupOptions struct { + ValueSetType arrow.DataType + TotalLen int64 + ValueSet []exec.ArraySpan + NullBehavior NullMatchingBehavior +} + +type lookupState interface { + Init(SetLookupOptions) error +} + +func CreateSetLookupState(opts SetLookupOptions, alloc memory.Allocator) (exec.KernelState, error) { + valueSetType := opts.ValueSetType + if valueSetType.ID() == arrow.EXTENSION { + valueSetType = valueSetType.(arrow.ExtensionType).StorageType() + } + + var state lookupState + switch ty := valueSetType.(type) { + case arrow.BinaryDataType: + switch ty.Layout().Buffers[1].ByteWidth { + case 4: + state = &SetLookupState[[]byte]{ + Alloc: alloc, + visitFn: visitBinary[int32], + } + case 8: + state = &SetLookupState[[]byte]{ + Alloc: alloc, + visitFn: visitBinary[int64], + } + } + case arrow.FixedWidthDataType: + switch ty.Bytes() { + case 1: + state = &SetLookupState[uint8]{ + Alloc: alloc, + visitFn: visitNumeric[uint8], + } + case 2: + state = &SetLookupState[uint16]{ + Alloc: alloc, + visitFn: visitNumeric[uint16], + } + case 4: + state = &SetLookupState[uint32]{ + Alloc: alloc, + visitFn: visitNumeric[uint32], + } + case 8: + state = &SetLookupState[uint64]{ + Alloc: alloc, + visitFn: visitNumeric[uint64], + } + default: + state = &SetLookupState[[]byte]{ + Alloc: alloc, + visitFn: visitFSB, + } + } + + default: + return nil, fmt.Errorf("%w: unsupported type %s for SetLookup functions", arrow.ErrInvalid, opts.ValueSetType) + } + + return state, state.Init(opts) +} + +type SetLookupState[T hashing.MemoTypes] struct { + visitFn func(*exec.ArraySpan, func(T) error, func() error) error + ValueSetType arrow.DataType + Alloc memory.Allocator + Lookup hashing.TypedMemoTable[T] + // When there are duplicates in value set, memotable indices + // must be mapped back to indices in the value set + MemoIndexToValueIndex []int32 + NullIndex int32 + NullBehavior NullMatchingBehavior +} + +func (s *SetLookupState[T]) ValueType() arrow.DataType { + return s.ValueSetType +} + +func (s *SetLookupState[T]) Init(opts SetLookupOptions) error { + s.ValueSetType = opts.ValueSetType + s.NullBehavior = opts.NullBehavior + s.MemoIndexToValueIndex = make([]int32, 0, opts.TotalLen) + s.NullIndex = -1 + memoType := s.ValueSetType.ID() + if memoType == arrow.EXTENSION { + memoType = s.ValueSetType.(arrow.ExtensionType).StorageType().ID() + } + lookup, err := newMemoTable(s.Alloc, memoType) + if err != nil { + return err + } + s.Lookup = lookup.(hashing.TypedMemoTable[T]) + if s.Lookup == nil { + return fmt.Errorf("unsupported type %s for SetLookup functions", s.ValueSetType) + } + + var offset int64 + for _, c := range opts.ValueSet { + if err := s.AddArrayValueSet(&c, offset); err != nil { + return err + } + offset += c.Len + } + + lookupNull, _ := s.Lookup.GetNull() + if s.NullBehavior != NullMatchingSkip && lookupNull >= 0 { + s.NullIndex = int32(lookupNull) + } + return nil +} + +func (s *SetLookupState[T]) AddArrayValueSet(data *exec.ArraySpan, startIdx int64) error { + idx := startIdx + return s.visitFn(data, + func(v T) error { + memoSize := len(s.MemoIndexToValueIndex) + memoIdx, found, err := s.Lookup.InsertOrGet(v) + if err != nil { + return err + } + + if !found { + debug.Assert(memoIdx == memoSize, "inconsistent memo index and size") + s.MemoIndexToValueIndex = append(s.MemoIndexToValueIndex, int32(idx)) + } else { + debug.Assert(memoIdx < memoSize, "inconsistent memo index and size") + } + + idx++ + return nil + }, func() error { + memoSize := len(s.MemoIndexToValueIndex) + nullIdx, found := s.Lookup.GetOrInsertNull() + if !found { + debug.Assert(nullIdx == memoSize, "inconsistent memo index and size") + s.MemoIndexToValueIndex = append(s.MemoIndexToValueIndex, int32(idx)) + } else { + debug.Assert(nullIdx < memoSize, "inconsistent memo index and size") + } + + idx++ + return nil + }) +} + +func DispatchIsIn(state lookupState, in *exec.ArraySpan, out *exec.ExecResult) error { + inType := in.Type + if inType.ID() == arrow.EXTENSION { + inType = inType.(arrow.ExtensionType).StorageType() + } + + switch ty := inType.(type) { + case arrow.BinaryDataType: + return isInKernelExec(state.(*SetLookupState[[]byte]), in, out) + case arrow.FixedWidthDataType: + switch ty.Bytes() { + case 1: + return isInKernelExec(state.(*SetLookupState[uint8]), in, out) + case 2: + return isInKernelExec(state.(*SetLookupState[uint16]), in, out) + case 4: + return isInKernelExec(state.(*SetLookupState[uint32]), in, out) + case 8: + return isInKernelExec(state.(*SetLookupState[uint64]), in, out) + default: + return isInKernelExec(state.(*SetLookupState[[]byte]), in, out) + } + default: + return fmt.Errorf("%w: unsupported type %s for is_in function", arrow.ErrInvalid, in.Type) + } +} + +func isInKernelExec[T hashing.MemoTypes](state *SetLookupState[T], in *exec.ArraySpan, out *exec.ExecResult) error { + writerBool := bitutil.NewBitmapWriter(out.Buffers[1].Buf, int(out.Offset), int(out.Len)) + defer writerBool.Finish() + writerNulls := bitutil.NewBitmapWriter(out.Buffers[0].Buf, int(out.Offset), int(out.Len)) + defer writerNulls.Finish() + valueSetHasNull := state.NullIndex != -1 + return state.visitFn(in, + func(v T) error { + switch { + case state.Lookup.Exists(v): + writerBool.Set() + writerNulls.Set() + case state.NullBehavior == NullMatchingInconclusive && valueSetHasNull: + writerBool.Clear() + writerNulls.Clear() + default: + writerBool.Clear() + writerNulls.Set() + } + + writerBool.Next() + writerNulls.Next() + return nil + }, func() error { + switch { + case state.NullBehavior == NullMatchingMatch && valueSetHasNull: + writerBool.Set() + writerNulls.Set() + case state.NullBehavior == NullMatchingSkip || (!valueSetHasNull && state.NullBehavior == NullMatchingMatch): + writerBool.Clear() + writerNulls.Set() + default: + writerBool.Clear() + writerNulls.Clear() + } + + writerBool.Next() + writerNulls.Next() + return nil + }) +} diff --git a/arrow/compute/internal/kernels/vector_hash.go b/arrow/compute/internal/kernels/vector_hash.go index 51968f7..bb0c561 100644 --- a/arrow/compute/internal/kernels/vector_hash.go +++ b/arrow/compute/internal/kernels/vector_hash.go @@ -345,10 +345,10 @@ func newMemoTable(mem memory.Allocator, dt arrow.Type) (hashing.MemoTable, error return hashing.NewUint8MemoTable(0), nil case arrow.INT16, arrow.UINT16: return hashing.NewUint16MemoTable(0), nil - case arrow.INT32, arrow.UINT32, arrow.FLOAT32, + case arrow.INT32, arrow.UINT32, arrow.FLOAT32, arrow.DECIMAL32, arrow.DATE32, arrow.TIME32, arrow.INTERVAL_MONTHS: return hashing.NewUint32MemoTable(0), nil - case arrow.INT64, arrow.UINT64, arrow.FLOAT64, + case arrow.INT64, arrow.UINT64, arrow.FLOAT64, arrow.DECIMAL64, arrow.DATE64, arrow.TIME64, arrow.TIMESTAMP, arrow.DURATION, arrow.INTERVAL_DAY_TIME: return hashing.NewUint64MemoTable(0), nil @@ -481,7 +481,7 @@ func uniqueFinalize(ctx *exec.KernelCtx, results []*exec.ArraySpan) ([]*exec.Arr return []*exec.ArraySpan{&out}, nil } -func ensureHashDictionary(ctx *exec.KernelCtx, hash *dictionaryHashState) (*exec.ArraySpan, error) { +func ensureHashDictionary(_ *exec.KernelCtx, hash *dictionaryHashState) (*exec.ArraySpan, error) { out := &exec.ArraySpan{} if hash.dictionary != nil { diff --git a/arrow/compute/internal/kernels/vector_selection.go b/arrow/compute/internal/kernels/vector_selection.go index 4a61940..9bbc863 100644 --- a/arrow/compute/internal/kernels/vector_selection.go +++ b/arrow/compute/internal/kernels/vector_selection.go @@ -906,13 +906,13 @@ func takeIdxDispatch[ValT arrow.IntType](values, indices *exec.ArraySpan, out *e switch indices.Type.(arrow.FixedWidthDataType).Bytes() { case 1: - primitiveTakeImpl[uint8, ValT](getter, indices, out) + primitiveTakeImpl[uint8](getter, indices, out) case 2: - primitiveTakeImpl[uint16, ValT](getter, indices, out) + primitiveTakeImpl[uint16](getter, indices, out) case 4: - primitiveTakeImpl[uint32, ValT](getter, indices, out) + primitiveTakeImpl[uint32](getter, indices, out) case 8: - primitiveTakeImpl[uint64, ValT](getter, indices, out) + primitiveTakeImpl[uint64](getter, indices, out) default: return fmt.Errorf("%w: invalid indices byte width", arrow.ErrIndex) } @@ -1147,7 +1147,7 @@ func filterExec(ctx *exec.KernelCtx, outputLen int64, values, selection *exec.Ar return nil } -func binaryFilterNonNull[OffsetT int32 | int64](ctx *exec.KernelCtx, values, filter *exec.ArraySpan, outputLen int64, nullSelection NullSelectionBehavior, out *exec.ExecResult) error { +func binaryFilterNonNull[OffsetT int32 | int64](ctx *exec.KernelCtx, values, filter *exec.ArraySpan, outputLen int64, _ NullSelectionBehavior, out *exec.ExecResult) error { var ( offsetBuilder = newBufferBuilder[OffsetT](exec.GetAllocator(ctx.Ctx)) dataBuilder = newBufferBuilder[uint8](exec.GetAllocator(ctx.Ctx)) diff --git a/arrow/compute/registry.go b/arrow/compute/registry.go index 12bc0b8..6b9250c 100644 --- a/arrow/compute/registry.go +++ b/arrow/compute/registry.go @@ -53,6 +53,7 @@ func GetFunctionRegistry() FunctionRegistry { RegisterScalarComparisons(registry) RegisterVectorHash(registry) RegisterVectorRunEndFuncs(registry) + RegisterScalarSetLookup(registry) }) return registry } diff --git a/arrow/compute/scalar_set_lookup.go b/arrow/compute/scalar_set_lookup.go new file mode 100644 index 0000000..81971ce --- /dev/null +++ b/arrow/compute/scalar_set_lookup.go @@ -0,0 +1,222 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package compute + +import ( + "context" + "errors" + "fmt" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/compute/exec" + "github.com/apache/arrow-go/v18/arrow/compute/internal/kernels" + "github.com/apache/arrow-go/v18/arrow/extensions" + "github.com/apache/arrow-go/v18/internal/hashing" +) + +var ( + isinDoc = FunctionDoc{ + Summary: "Find each element in a set of values", + Description: `For each element in "values", return true if it is found +in a given set, false otherwise`, + ArgNames: []string{"values"}, + OptionsType: "SetOptions", + OptionsRequired: true, + } +) + +type NullMatchingBehavior = kernels.NullMatchingBehavior + +const ( + NullMatchingMatch = kernels.NullMatchingMatch + NullMatchingSkip = kernels.NullMatchingSkip + NullMatchingEmitNull = kernels.NullMatchingEmitNull + NullMatchingInconclusive = kernels.NullMatchingInconclusive +) + +type setLookupFunc struct { + ScalarFunction +} + +func (fn *setLookupFunc) Execute(ctx context.Context, opts FunctionOptions, args ...Datum) (Datum, error) { + return execInternal(ctx, fn, opts, -1, args...) +} + +func (fn *setLookupFunc) DispatchBest(vals ...arrow.DataType) (exec.Kernel, error) { + ensureDictionaryDecoded(vals...) + return fn.DispatchExact(vals...) +} + +type SetOptions struct { + ValueSet Datum + NullBehavior NullMatchingBehavior +} + +func (*SetOptions) TypeName() string { return "SetOptions" } + +func initSetLookup(ctx *exec.KernelCtx, args exec.KernelInitArgs) (exec.KernelState, error) { + if args.Options == nil { + return nil, fmt.Errorf("%w: calling a set lookup function without SetOptions", ErrInvalid) + } + + opts, ok := args.Options.(*SetOptions) + if !ok { + return nil, fmt.Errorf("%w: expected SetOptions, got %T", ErrInvalid, args.Options) + } + + valueset, ok := opts.ValueSet.(ArrayLikeDatum) + if !ok { + return nil, fmt.Errorf("%w: expected array-like datum, got %T", ErrInvalid, opts.ValueSet) + } + + argType := args.Inputs[0] + if (argType.ID() == arrow.STRING || argType.ID() == arrow.LARGE_STRING) && !arrow.IsBaseBinary(valueset.Type().ID()) { + // don't implicitly cast from a non-binary type to string + // since most types support casting to string and that may lead to + // surprises. However we do want most other implicit casts + return nil, fmt.Errorf("%w: array type doesn't match type of values set: %s vs %s", + ErrInvalid, argType, valueset.Type()) + } + + if !arrow.TypeEqual(valueset.Type(), argType) { + result, err := CastDatum(ctx.Ctx, valueset, SafeCastOptions(argType)) + if err == nil { + defer result.Release() + valueset = result.(ArrayLikeDatum) + } else if CanCast(argType, valueset.Type()) { + // avoid casting from non-binary types to string like above + // otherwise will try to cast input array to valueset during + // execution + if (valueset.Type().ID() == arrow.STRING || valueset.Type().ID() == arrow.LARGE_STRING) && !arrow.IsBaseBinary(argType.ID()) { + return nil, fmt.Errorf("%w: array type doesn't match type of values set: %s vs %s", + ErrInvalid, argType, valueset.Type()) + } + } else { + return nil, fmt.Errorf("%w: array type doesn't match type of values set: %s vs %s", + ErrInvalid, argType, valueset.Type()) + } + + } + + internalOpts := kernels.SetLookupOptions{ + ValueSet: make([]exec.ArraySpan, 1), + TotalLen: opts.ValueSet.Len(), + NullBehavior: opts.NullBehavior, + } + + switch valueset.Kind() { + case KindArray: + internalOpts.ValueSet[0].SetMembers(valueset.(*ArrayDatum).Value) + internalOpts.ValueSetType = valueset.(*ArrayDatum).Type() + case KindChunked: + chnked := valueset.(*ChunkedDatum).Value + internalOpts.ValueSetType = chnked.DataType() + internalOpts.ValueSet = make([]exec.ArraySpan, len(chnked.Chunks())) + for i, c := range chnked.Chunks() { + internalOpts.ValueSet[i].SetMembers(c.Data()) + } + default: + return nil, fmt.Errorf("%w: expected array or chunked array, got %s", ErrInvalid, opts.ValueSet.Kind()) + } + + return kernels.CreateSetLookupState(internalOpts, exec.GetAllocator(ctx.Ctx)) +} + +type setLookupState interface { + Init(kernels.SetLookupOptions) error + ValueType() arrow.DataType +} + +func execIsIn(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error { + state := ctx.State.(setLookupState) + ctx.Kernel.(*exec.ScalarKernel).Data = state + in := batch.Values[0] + + if !arrow.TypeEqual(in.Type(), state.ValueType()) { + materialized := in.Array.MakeArray() + defer materialized.Release() + + castResult, err := CastArray(ctx.Ctx, materialized, SafeCastOptions(state.ValueType())) + if err != nil { + if errors.Is(err, arrow.ErrNotImplemented) { + return fmt.Errorf("%w: array type doesn't match type of values set: %s vs %s", + ErrInvalid, in.Type(), state.ValueType()) + } + return err + } + defer castResult.Release() + + var casted exec.ArraySpan + casted.SetMembers(castResult.Data()) + return kernels.DispatchIsIn(state, &casted, out) + } + + return kernels.DispatchIsIn(state, &in.Array, out) +} + +func IsIn(ctx context.Context, opts SetOptions, values Datum) (Datum, error) { + return CallFunction(ctx, "is_in", &opts, values) +} + +func IsInSet(ctx context.Context, valueSet, values Datum) (Datum, error) { + return IsIn(ctx, SetOptions{ValueSet: valueSet}, values) +} + +func RegisterScalarSetLookup(reg FunctionRegistry) { + inBase := NewScalarFunction("is_in", Unary(), isinDoc) + + types := []exec.InputType{ + exec.NewMatchedInput(exec.Primitive()), + exec.NewIDInput(arrow.DECIMAL32), + exec.NewIDInput(arrow.DECIMAL64), + } + + outType := exec.NewOutputType(arrow.FixedWidthTypes.Boolean) + for _, ty := range types { + kn := exec.NewScalarKernel([]exec.InputType{ty}, outType, execIsIn, initSetLookup) + kn.MemAlloc = exec.MemPrealloc + kn.NullHandling = exec.NullComputedPrealloc + if err := inBase.AddKernel(kn); err != nil { + panic(err) + } + } + + binaryTypes := []exec.InputType{ + exec.NewMatchedInput(exec.BinaryLike()), + exec.NewMatchedInput(exec.LargeBinaryLike()), + exec.NewExactInput(extensions.NewUUIDType()), + exec.NewIDInput(arrow.FIXED_SIZE_BINARY), + exec.NewIDInput(arrow.DECIMAL128), + exec.NewIDInput(arrow.DECIMAL256), + } + for _, ty := range binaryTypes { + kn := exec.NewScalarKernel([]exec.InputType{ty}, outType, execIsIn, initSetLookup) + kn.MemAlloc = exec.MemPrealloc + kn.NullHandling = exec.NullComputedPrealloc + kn.CleanupFn = func(state exec.KernelState) error { + s := state.(*kernels.SetLookupState[[]byte]) + s.Lookup.(*hashing.BinaryMemoTable).Release() + return nil + } + + if err := inBase.AddKernel(kn); err != nil { + panic(err) + } + } + + reg.AddFunction(&setLookupFunc{*inBase}, false) +} diff --git a/arrow/compute/scalar_set_lookup_test.go b/arrow/compute/scalar_set_lookup_test.go new file mode 100644 index 0000000..770b984 --- /dev/null +++ b/arrow/compute/scalar_set_lookup_test.go @@ -0,0 +1,606 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package compute_test + +import ( + "context" + "strings" + "testing" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/compute" + "github.com/apache/arrow-go/v18/arrow/memory" + "github.com/stretchr/testify/suite" +) + +type ScalarSetLookupSuite struct { + suite.Suite + + mem *memory.CheckedAllocator + ctx context.Context +} + +func (ss *ScalarSetLookupSuite) SetupTest() { + ss.mem = memory.NewCheckedAllocator(memory.DefaultAllocator) + ss.ctx = compute.WithAllocator(context.TODO(), ss.mem) +} + +func (ss *ScalarSetLookupSuite) getArr(dt arrow.DataType, str string) arrow.Array { + arr, _, err := array.FromJSON(ss.mem, dt, strings.NewReader(str), array.WithUseNumber()) + ss.Require().NoError(err) + return arr +} + +func (ss *ScalarSetLookupSuite) checkIsIn(input, valueSet arrow.Array, expectedJSON string, matching compute.NullMatchingBehavior) { + expected := ss.getArr(arrow.FixedWidthTypes.Boolean, expectedJSON) + defer expected.Release() + + result, err := compute.IsIn(ss.ctx, compute.SetOptions{ + ValueSet: compute.NewDatumWithoutOwning(valueSet), + NullBehavior: matching, + }, compute.NewDatumWithoutOwning(input)) + ss.Require().NoError(err) + defer result.Release() + + assertDatumsEqual(ss.T(), compute.NewDatumWithoutOwning(expected), result, nil, nil) +} + +func (ss *ScalarSetLookupSuite) checkIsInFromJSON(typ arrow.DataType, input, valueSet, expected string, matching compute.NullMatchingBehavior) { + inputArr := ss.getArr(typ, input) + defer inputArr.Release() + + valueSetArr := ss.getArr(typ, valueSet) + defer valueSetArr.Release() + + ss.checkIsIn(inputArr, valueSetArr, expected, matching) +} + +func (ss *ScalarSetLookupSuite) checkIsInDictionary(typ, idxType arrow.DataType, inputDict, inputIndex, valueSet, expected string, matching compute.NullMatchingBehavior) { + dictType := &arrow.DictionaryType{IndexType: idxType, ValueType: typ} + indices := ss.getArr(idxType, inputIndex) + defer indices.Release() + dict := ss.getArr(typ, inputDict) + defer dict.Release() + + input := array.NewDictionaryArray(dictType, indices, dict) + defer input.Release() + + valueSetArr := ss.getArr(typ, valueSet) + defer valueSetArr.Release() + + ss.checkIsIn(input, valueSetArr, expected, matching) +} + +func (ss *ScalarSetLookupSuite) checkIsInChunked(input, value, expected *arrow.Chunked, matching compute.NullMatchingBehavior) { + result, err := compute.IsIn(ss.ctx, compute.SetOptions{ + ValueSet: compute.NewDatumWithoutOwning(value), + NullBehavior: matching, + }, compute.NewDatumWithoutOwning(input)) + ss.Require().NoError(err) + defer result.Release() + + ss.Len(result.(*compute.ChunkedDatum).Chunks(), 1) + assertDatumsEqual(ss.T(), compute.NewDatumWithoutOwning(expected), result, nil, nil) +} + +func (ss *ScalarSetLookupSuite) TestIsInPrimitive() { + type testCase struct { + expected string + matching compute.NullMatchingBehavior + } + + tests := []struct { + name string + input string + valueset string + cases []testCase + }{ + {"no nulls", `[0, 1, 2, 3, 2]`, `[2, 1]`, []testCase{ + {`[false, true, true, false, true]`, compute.NullMatchingMatch}, + }}, + {"nulls in left", `[null, 1, 2, 3, 2]`, `[2, 1]`, []testCase{ + {`[false, true, true, false, true]`, compute.NullMatchingMatch}, + {`[false, true, true, false, true]`, compute.NullMatchingSkip}, + {`[null, true, true, false, true]`, compute.NullMatchingEmitNull}, + {`[null, true, true, false, true]`, compute.NullMatchingInconclusive}, + }}, + {"nulls in right", `[0, 1, 2, 3, 2]`, `[2, null, 1]`, []testCase{ + {`[false, true, true, false, true]`, compute.NullMatchingMatch}, + {`[false, true, true, false, true]`, compute.NullMatchingSkip}, + {`[false, true, true, false, true]`, compute.NullMatchingEmitNull}, + {`[null, true, true, null, true]`, compute.NullMatchingInconclusive}, + }}, + {"nulls in both", `[null, 1, 2, 3, 2]`, `[2, null, 1]`, []testCase{ + {`[true, true, true, false, true]`, compute.NullMatchingMatch}, + {`[false, true, true, false, true]`, compute.NullMatchingSkip}, + {`[null, true, true, false, true]`, compute.NullMatchingEmitNull}, + {`[null, true, true, null, true]`, compute.NullMatchingInconclusive}, + }}, + {"duplicates in right", `[null, 1, 2, 3, 2]`, `[null, 2, 2, null, 1, 1]`, []testCase{ + {`[true, true, true, false, true]`, compute.NullMatchingMatch}, + {`[false, true, true, false, true]`, compute.NullMatchingSkip}, + {`[null, true, true, false, true]`, compute.NullMatchingEmitNull}, + {`[null, true, true, null, true]`, compute.NullMatchingInconclusive}, + }}, + {"empty arrays", `[]`, `[]`, []testCase{ + {`[]`, compute.NullMatchingMatch}, + }}, + } + + typList := append([]arrow.DataType{}, numericTypes...) + typList = append(typList, arrow.FixedWidthTypes.Time32s, + arrow.FixedWidthTypes.Time32ms, arrow.FixedWidthTypes.Time64us, + arrow.FixedWidthTypes.Time64ns, arrow.FixedWidthTypes.Timestamp_us, + arrow.FixedWidthTypes.Timestamp_ns, arrow.FixedWidthTypes.Duration_s, + arrow.FixedWidthTypes.Duration_ms, arrow.FixedWidthTypes.Duration_us, + arrow.FixedWidthTypes.Duration_ns) + + for _, typ := range typList { + ss.Run(typ.String(), func() { + for _, tt := range tests { + ss.Run(tt.name, func() { + for _, tc := range tt.cases { + ss.checkIsInFromJSON(typ, + tt.input, tt.valueset, tc.expected, tc.matching) + } + }) + } + }) + } +} + +func (ss *ScalarSetLookupSuite) TestDurationCasts() { + vals := ss.getArr(arrow.FixedWidthTypes.Duration_s, `[0, 1, 2]`) + defer vals.Release() + + valueset := ss.getArr(arrow.FixedWidthTypes.Duration_ms, `[1, 2, 2000]`) + defer valueset.Release() + + ss.checkIsIn(vals, valueset, `[false, false, true]`, compute.NullMatchingMatch) +} + +func (ss *ScalarSetLookupSuite) TestIsInBinary() { + type testCase struct { + expected string + matching compute.NullMatchingBehavior + } + + tests := []struct { + name string + input string + valueset string + cases []testCase + }{ + {"nulls on left", `["YWFh", "", "Y2M=", null, ""]`, `["YWFh", ""]`, []testCase{ + {`[true, true, false, false, true]`, compute.NullMatchingMatch}, + {`[true, true, false, false, true]`, compute.NullMatchingSkip}, + {`[true, true, false, null, true]`, compute.NullMatchingEmitNull}, + {`[true, true, false, null, true]`, compute.NullMatchingInconclusive}, + }}, + {"nulls on right", `["YWFh", "", "Y2M=", null, ""]`, `["YWFh", "", null]`, []testCase{ + {`[true, true, false, true, true]`, compute.NullMatchingMatch}, + {`[true, true, false, false, true]`, compute.NullMatchingSkip}, + {`[true, true, false, null, true]`, compute.NullMatchingEmitNull}, + {`[true, true, null, null, true]`, compute.NullMatchingInconclusive}, + }}, + {"duplicates in right array", `["YWFh", "", "Y2M=", null, ""]`, `[null, "YWFh", "YWFh", "", "", null]`, []testCase{ + {`[true, true, false, true, true]`, compute.NullMatchingMatch}, + {`[true, true, false, false, true]`, compute.NullMatchingSkip}, + {`[true, true, false, null, true]`, compute.NullMatchingEmitNull}, + {`[true, true, null, null, true]`, compute.NullMatchingInconclusive}, + }}, + } + + for _, typ := range baseBinaryTypes { + ss.Run(typ.String(), func() { + for _, tt := range tests { + ss.Run(tt.name, func() { + for _, tc := range tt.cases { + ss.checkIsInFromJSON(typ, + tt.input, tt.valueset, tc.expected, tc.matching) + } + }) + } + }) + } +} + +func (ss *ScalarSetLookupSuite) TestIsInFixedSizeBinary() { + type testCase struct { + expected string + matching compute.NullMatchingBehavior + } + + tests := []struct { + name string + input string + valueset string + cases []testCase + }{ + {"nulls on left", `["YWFh", "YmJi", "Y2Nj", null, "YmJi"]`, `["YWFh", "YmJi"]`, []testCase{ + {`[true, true, false, false, true]`, compute.NullMatchingMatch}, + {`[true, true, false, false, true]`, compute.NullMatchingSkip}, + {`[true, true, false, null, true]`, compute.NullMatchingEmitNull}, + {`[true, true, false, null, true]`, compute.NullMatchingInconclusive}, + }}, + {"nulls on right", `["YWFh", "YmJi", "Y2Nj", null, "YmJi"]`, `["YWFh", "YmJi", null]`, []testCase{ + {`[true, true, false, true, true]`, compute.NullMatchingMatch}, + {`[true, true, false, false, true]`, compute.NullMatchingSkip}, + {`[true, true, false, null, true]`, compute.NullMatchingEmitNull}, + {`[true, true, null, null, true]`, compute.NullMatchingInconclusive}, + }}, + {"duplicates in right array", `["YWFh", "YmJi", "Y2Nj", null, "YmJi"]`, `["YWFh", null, "YWFh", "YmJi", "YmJi", null]`, []testCase{ + {`[true, true, false, true, true]`, compute.NullMatchingMatch}, + {`[true, true, false, false, true]`, compute.NullMatchingSkip}, + {`[true, true, false, null, true]`, compute.NullMatchingEmitNull}, + {`[true, true, null, null, true]`, compute.NullMatchingInconclusive}, + }}, + } + + typ := &arrow.FixedSizeBinaryType{ByteWidth: 3} + for _, tt := range tests { + ss.Run(tt.name, func() { + for _, tc := range tt.cases { + ss.checkIsInFromJSON(typ, + tt.input, tt.valueset, tc.expected, tc.matching) + } + }) + } +} + +func (ss *ScalarSetLookupSuite) TestIsInDecimal() { + type testCase struct { + expected string + matching compute.NullMatchingBehavior + } + + tests := []struct { + name string + input string + valueset string + cases []testCase + }{ + {"nulls on left", `["12.3", "45.6", "78.9", null, "12.3"]`, `["12.3", "78.9"]`, []testCase{ + {`[true, false, true, false, true]`, compute.NullMatchingMatch}, + {`[true, false, true, false, true]`, compute.NullMatchingSkip}, + {`[true, false, true, null, true]`, compute.NullMatchingEmitNull}, + {`[true, false, true, null, true]`, compute.NullMatchingInconclusive}, + }}, + {"nulls on right", `["12.3", "45.6", "78.9", null, "12.3"]`, `["12.3", "78.9", null]`, []testCase{ + {`[true, false, true, true, true]`, compute.NullMatchingMatch}, + {`[true, false, true, false, true]`, compute.NullMatchingSkip}, + {`[true, false, true, null, true]`, compute.NullMatchingEmitNull}, + {`[true, null, true, null, true]`, compute.NullMatchingInconclusive}, + }}, + {"duplicates in right array", `["12.3", "45.6", "78.9", null, "12.3"]`, `[null, "12.3", "12.3", "78.9", "78.9", null]`, []testCase{ + {`[true, false, true, true, true]`, compute.NullMatchingMatch}, + {`[true, false, true, false, true]`, compute.NullMatchingSkip}, + {`[true, false, true, null, true]`, compute.NullMatchingEmitNull}, + {`[true, null, true, null, true]`, compute.NullMatchingInconclusive}, + }}, + } + + decTypes := []arrow.DataType{ + &arrow.Decimal32Type{Precision: 3, Scale: 1}, + &arrow.Decimal64Type{Precision: 3, Scale: 1}, + &arrow.Decimal128Type{Precision: 3, Scale: 1}, + &arrow.Decimal256Type{Precision: 3, Scale: 1}, + } + + for _, typ := range decTypes { + ss.Run(typ.String(), func() { + for _, tt := range tests { + ss.Run(tt.name, func() { + for _, tc := range tt.cases { + ss.checkIsInFromJSON(typ, + tt.input, tt.valueset, tc.expected, tc.matching) + } + }) + } + + // don't yet have Decimal32 or Decimal64 implemented for casting + if typ.ID() == arrow.DECIMAL128 || typ.ID() == arrow.DECIMAL256 { + // test cast + in := ss.getArr(&arrow.Decimal128Type{Precision: 4, Scale: 2}, `["12.30", "45.60", "78.90"]`) + defer in.Release() + values := ss.getArr(typ, `["12.3", "78.9"]`) + defer values.Release() + + ss.checkIsIn(in, values, `[true, false true]`, compute.NullMatchingMatch) + } + }) + } +} + +func (ss *ScalarSetLookupSuite) TestIsInDictionary() { + tests := []struct { + typ arrow.DataType + inputDict string + inputIdx string + valueSet string + expected string + matching compute.NullMatchingBehavior + }{ + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", "B", "C", "D"]`, + inputIdx: `[1, 2, null, 0]`, + valueSet: `["A", "B", "C"]`, + expected: `[true, true, false, true]`, + matching: compute.NullMatchingMatch, + }, + { + typ: arrow.PrimitiveTypes.Float32, + inputDict: `[4.1, -1.0, 42, 9.8]`, + inputIdx: `[1, 2, null, 0]`, + valueSet: `[4.1, 42, -1.0]`, + expected: `[true, true, false, true]`, + matching: compute.NullMatchingMatch, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", "B", "C", "D"]`, + inputIdx: `[1, 3, null, 0, 1]`, + valueSet: `["C", "B", "A", null]`, + expected: `[true, false, true, true, true]`, + matching: compute.NullMatchingMatch, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", null, "C", "D"]`, + inputIdx: `[1, 3, null, 0, 1]`, + valueSet: `["C", "B", "A", null]`, + expected: `[true, false, true, true, true]`, + matching: compute.NullMatchingMatch, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", null, "C", "D"]`, + inputIdx: `[1, 3, null, 0, 1]`, + valueSet: `["C", "B", "A"]`, + expected: `[false, false, false, true, false]`, + matching: compute.NullMatchingMatch, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", "B", "C", "D"]`, + inputIdx: `[1, 3, null, 0, 1]`, + valueSet: `["C", "B", "A", null]`, + expected: `[true, false, false, true, true]`, + matching: compute.NullMatchingSkip, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", null, "C", "D"]`, + inputIdx: `[1, 3, null, 0, 1]`, + valueSet: `["C", "B", "A", null]`, + expected: `[false, false, false, true, false]`, + matching: compute.NullMatchingSkip, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", null, "C", "D"]`, + inputIdx: `[1, 3, null, 0, 1]`, + valueSet: `["C", "B", "A"]`, + expected: `[false, false, false, true, false]`, + matching: compute.NullMatchingSkip, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", "B", "C", "D"]`, + inputIdx: `[1, 3, null, 0, 1]`, + valueSet: `["C", "B", "A", null]`, + expected: `[true, false, null, true, true]`, + matching: compute.NullMatchingEmitNull, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", null, "C", "D"]`, + inputIdx: `[1, 3, null, 0, 1]`, + valueSet: `["C", "B", "A", null]`, + expected: `[null, false, null, true, null]`, + matching: compute.NullMatchingEmitNull, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", null, "C", "D"]`, + inputIdx: `[1, 3, null, 0, 1]`, + valueSet: `["C", "B", "A"]`, + expected: `[null, false, null, true, null]`, + matching: compute.NullMatchingEmitNull, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", "B", "C", "D"]`, + inputIdx: `[1, 3, null, 0, 1]`, + valueSet: `["C", "B", "A", null]`, + expected: `[true, null, null, true, true]`, + matching: compute.NullMatchingInconclusive, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", null, "C", "D"]`, + inputIdx: `[1, 3, null, 0, 1]`, + valueSet: `["C", "B", "A", null]`, + expected: `[null, null, null, true, null]`, + matching: compute.NullMatchingInconclusive, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", null, "C", "D"]`, + inputIdx: `[1, 3, null, 0, 1]`, + valueSet: `["C", "B", "A"]`, + expected: `[null, false, null, true, null]`, + matching: compute.NullMatchingInconclusive, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", "B", "C", "D"]`, + inputIdx: `[1, 2, null, 0]`, + valueSet: `["A", "A", "B", "A", "B", "C"]`, + expected: `[true, true, false, true]`, + matching: compute.NullMatchingMatch, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", "B", "C", "D"]`, + inputIdx: `[1, 3, null, 0, 1]`, + valueSet: `["C", "C", "B", "A", null, null, "B"]`, + expected: `[true, false, true, true, true]`, + matching: compute.NullMatchingMatch, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", "B", "C", "D"]`, + inputIdx: `[1, 3, null, 0, 1]`, + valueSet: `["C", "C", "B", "A", null, null, "B"]`, + expected: `[true, false, false, true, true]`, + matching: compute.NullMatchingSkip, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", "B", "C", "D"]`, + inputIdx: `[1, 3, null, 0, 1]`, + valueSet: `["C", "C", "B", "A", null, null, "B"]`, + expected: `[true, false, null, true, true]`, + matching: compute.NullMatchingEmitNull, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", "B", "C", "D"]`, + inputIdx: `[1, 3, null, 0, 1]`, + valueSet: `["C", "C", "B", "A", null, null, "B"]`, + expected: `[true, null, null, true, true]`, + matching: compute.NullMatchingInconclusive, + }, + } + + for _, ty := range dictIndexTypes { + ss.Run("idx="+ty.String(), func() { + for _, test := range tests { + ss.Run(test.typ.String(), func() { + ss.checkIsInDictionary(test.typ, ty, + test.inputDict, test.inputIdx, test.valueSet, + test.expected, test.matching) + }) + } + }) + } +} + +func (ss *ScalarSetLookupSuite) TestIsInChunked() { + input, err := array.ChunkedFromJSON(ss.mem, arrow.BinaryTypes.String, + []string{`["abc", "def", "", "abc", "jkl"]`, `["def", null, "abc", "zzz"]`}) + ss.Require().NoError(err) + defer input.Release() + + valueSet, err := array.ChunkedFromJSON(ss.mem, arrow.BinaryTypes.String, + []string{`["", "def"]`, `["abc"]`}) + ss.Require().NoError(err) + defer valueSet.Release() + + expected, err := array.ChunkedFromJSON(ss.mem, arrow.FixedWidthTypes.Boolean, + []string{`[true, true, true, true, false]`, `[true, false, true, false]`}) + ss.Require().NoError(err) + defer expected.Release() + + ss.checkIsInChunked(input, valueSet, expected, compute.NullMatchingMatch) + ss.checkIsInChunked(input, valueSet, expected, compute.NullMatchingSkip) + + expected, err = array.ChunkedFromJSON(ss.mem, arrow.FixedWidthTypes.Boolean, + []string{`[true, true, true, true, false]`, `[true, null, true, false]`}) + ss.Require().NoError(err) + defer expected.Release() + + ss.checkIsInChunked(input, valueSet, expected, compute.NullMatchingEmitNull) + ss.checkIsInChunked(input, valueSet, expected, compute.NullMatchingInconclusive) + + valueSet, err = array.ChunkedFromJSON(ss.mem, arrow.BinaryTypes.String, + []string{`["", "def"]`, `[null]`}) + ss.Require().NoError(err) + defer valueSet.Release() + + expected, err = array.ChunkedFromJSON(ss.mem, arrow.FixedWidthTypes.Boolean, + []string{`[false, true, true, false, false]`, `[true, true, false, false]`}) + ss.Require().NoError(err) + defer expected.Release() + + ss.checkIsInChunked(input, valueSet, expected, compute.NullMatchingMatch) + + expected, err = array.ChunkedFromJSON(ss.mem, arrow.FixedWidthTypes.Boolean, + []string{`[false, true, true, false, false]`, `[true, false, false, false]`}) + ss.Require().NoError(err) + defer expected.Release() + + ss.checkIsInChunked(input, valueSet, expected, compute.NullMatchingSkip) + + expected, err = array.ChunkedFromJSON(ss.mem, arrow.FixedWidthTypes.Boolean, + []string{`[false, true, true, false, false]`, `[true, null, false, false]`}) + ss.Require().NoError(err) + defer expected.Release() + + ss.checkIsInChunked(input, valueSet, expected, compute.NullMatchingEmitNull) + + expected, err = array.ChunkedFromJSON(ss.mem, arrow.FixedWidthTypes.Boolean, + []string{`[null, true, true, null, null]`, `[true, null, null, null]`}) + ss.Require().NoError(err) + defer expected.Release() + + ss.checkIsInChunked(input, valueSet, expected, compute.NullMatchingInconclusive) + + valueSet, err = array.ChunkedFromJSON(ss.mem, arrow.BinaryTypes.String, + []string{`["", null, "", "def"]`, `["def", null]`}) + ss.Require().NoError(err) + defer valueSet.Release() + + expected, err = array.ChunkedFromJSON(ss.mem, arrow.FixedWidthTypes.Boolean, + []string{`[false, true, true, false, false]`, `[true, true, false, false]`}) + ss.Require().NoError(err) + defer expected.Release() + + ss.checkIsInChunked(input, valueSet, expected, compute.NullMatchingMatch) + + expected, err = array.ChunkedFromJSON(ss.mem, arrow.FixedWidthTypes.Boolean, + []string{`[false, true, true, false, false]`, `[true, false, false, false]`}) + ss.Require().NoError(err) + defer expected.Release() + + ss.checkIsInChunked(input, valueSet, expected, compute.NullMatchingSkip) + + expected, err = array.ChunkedFromJSON(ss.mem, arrow.FixedWidthTypes.Boolean, + []string{`[false, true, true, false, false]`, `[true, null, false, false]`}) + ss.Require().NoError(err) + defer expected.Release() + + ss.checkIsInChunked(input, valueSet, expected, compute.NullMatchingEmitNull) + + expected, err = array.ChunkedFromJSON(ss.mem, arrow.FixedWidthTypes.Boolean, + []string{`[null, true, true, null, null]`, `[true, null, null, null]`}) + ss.Require().NoError(err) + defer expected.Release() + + ss.checkIsInChunked(input, valueSet, expected, compute.NullMatchingInconclusive) +} + +func (ss *ScalarSetLookupSuite) TearDownTest() { + ss.mem.AssertSize(ss.T(), 0) +} + +func TestScalarSetLookup(t *testing.T) { + suite.Run(t, new(ScalarSetLookupSuite)) +} diff --git a/arrow/scalar/append.go b/arrow/scalar/append.go index 0525bc8..737e800 100644 --- a/arrow/scalar/append.go +++ b/arrow/scalar/append.go @@ -76,41 +76,32 @@ func appendBinary(bldr binaryBuilder, scalars []Scalar) { } } -// Append requires the passed in builder and scalar to have the same datatype -// otherwise it will return an error. Will return arrow.ErrNotImplemented if -// the type hasn't been implemented for this. -// -// NOTE only available in go1.18+ -func Append(bldr array.Builder, s Scalar) error { - return AppendSlice(bldr, []Scalar{s}) +type extbuilder interface { + array.Builder + StorageBuilder() array.Builder } -// AppendSlice requires the passed in builder and all scalars in the slice -// to have the same datatype otherwise it will return an error. Will return -// arrow.ErrNotImplemented if the type hasn't been implemented for this. -// -// NOTE only available in go1.18+ -func AppendSlice(bldr array.Builder, scalars []Scalar) error { +func appendToBldr(bldr array.Builder, scalars []Scalar) error { if len(scalars) == 0 { return nil } ty := bldr.Type() - for _, sc := range scalars { - if !arrow.TypeEqual(ty, sc.DataType()) { - return fmt.Errorf("%w: cannot append scalar of type %s to builder for type %s", - arrow.ErrInvalid, scalars[0].DataType(), bldr.Type()) - } - } - bldr.Reserve(len(scalars)) switch bldr := bldr.(type) { + case extbuilder: + baseScalars := make([]Scalar, len(scalars)) + for i, sc := range scalars { + baseScalars[i] = sc.(*Extension).Value + } + + return appendToBldr(bldr.StorageBuilder(), baseScalars) case *array.BooleanBuilder: - appendPrimitive[bool](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Decimal128Builder: - appendPrimitive[decimal128.Num](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Decimal256Builder: - appendPrimitive[decimal256.Num](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.FixedSizeBinaryBuilder: for _, sc := range scalars { s := sc.(*FixedSizeBinary) @@ -121,45 +112,45 @@ func AppendSlice(bldr array.Builder, scalars []Scalar) error { } } case *array.Int8Builder: - appendPrimitive[int8](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Uint8Builder: - appendPrimitive[uint8](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Int16Builder: - appendPrimitive[int16](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Uint16Builder: - appendPrimitive[uint16](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Int32Builder: - appendPrimitive[int32](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Uint32Builder: - appendPrimitive[uint32](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Int64Builder: - appendPrimitive[int64](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Uint64Builder: - appendPrimitive[uint64](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Float16Builder: - appendPrimitive[float16.Num](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Float32Builder: - appendPrimitive[float32](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Float64Builder: - appendPrimitive[float64](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Date32Builder: - appendPrimitive[arrow.Date32](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Date64Builder: - appendPrimitive[arrow.Date64](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Time32Builder: - appendPrimitive[arrow.Time32](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Time64Builder: - appendPrimitive[arrow.Time64](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.DayTimeIntervalBuilder: - appendPrimitive[arrow.DayTimeInterval](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.MonthIntervalBuilder: - appendPrimitive[arrow.MonthInterval](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.MonthDayNanoIntervalBuilder: - appendPrimitive[arrow.MonthDayNanoInterval](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.DurationBuilder: - appendPrimitive[arrow.Duration](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.TimestampBuilder: - appendPrimitive[arrow.Timestamp](bldr, scalars) + appendPrimitive(bldr, scalars) case array.StringLikeBuilder: appendBinary(bldr, scalars) case *array.BinaryBuilder: @@ -261,3 +252,33 @@ func AppendSlice(bldr array.Builder, scalars []Scalar) error { return nil } + +// Append requires the passed in builder and scalar to have the same datatype +// otherwise it will return an error. Will return arrow.ErrNotImplemented if +// the type hasn't been implemented for this. +// +// NOTE only available in go1.18+ +func Append(bldr array.Builder, s Scalar) error { + return AppendSlice(bldr, []Scalar{s}) +} + +// AppendSlice requires the passed in builder and all scalars in the slice +// to have the same datatype otherwise it will return an error. Will return +// arrow.ErrNotImplemented if the type hasn't been implemented for this. +// +// NOTE only available in go1.18+ +func AppendSlice(bldr array.Builder, scalars []Scalar) error { + if len(scalars) == 0 { + return nil + } + + ty := bldr.Type() + for _, sc := range scalars { + if !arrow.TypeEqual(ty, sc.DataType()) { + return fmt.Errorf("%w: cannot append scalar of type %s to builder for type %s", + arrow.ErrInvalid, scalars[0].DataType(), bldr.Type()) + } + } + + return appendToBldr(bldr, scalars) +} diff --git a/internal/hashing/xxh3_memo_table.gen.go b/internal/hashing/xxh3_memo_table.gen.go index e99a4f8..5f105f6 100644 --- a/internal/hashing/xxh3_memo_table.gen.go +++ b/internal/hashing/xxh3_memo_table.gen.go @@ -267,6 +267,11 @@ func (s *Int8MemoTable) WriteOutSubsetLE(start int, out []byte) { s.tbl.WriteOutSubset(start, out) } +func (s *Int8MemoTable) Exists(val int8) bool { + _, ok := s.Get(val) + return ok +} + // Get returns the index of the requested value in the hash table or KeyNotFound // along with a boolean indicating if it was found or not. func (s *Int8MemoTable) Get(val interface{}) (int, bool) { @@ -282,10 +287,13 @@ func (s *Int8MemoTable) Get(val interface{}) (int, bool) { // value into the table and return the new index. found indicates whether or not it already // existed in the table (true) or was inserted by this call (false). func (s *Int8MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { + return s.InsertOrGet(val.(int8)) +} - h := hashInt(uint64(val.(int8)), 0) +func (s *Int8MemoTable) InsertOrGet(val int8) (idx int, found bool, err error) { + h := hashInt(uint64(val), 0) e, ok := s.tbl.Lookup(h, func(v int8) bool { - return val.(int8) == v + return val == v }) if ok { @@ -293,7 +301,7 @@ func (s *Int8MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err e found = true } else { idx = s.Size() - s.tbl.Insert(e, h, val.(int8), int32(idx)) + s.tbl.Insert(e, h, val, int32(idx)) } return } @@ -544,6 +552,11 @@ func (s *Uint8MemoTable) WriteOutSubsetLE(start int, out []byte) { s.tbl.WriteOutSubset(start, out) } +func (s *Uint8MemoTable) Exists(val uint8) bool { + _, ok := s.Get(val) + return ok +} + // Get returns the index of the requested value in the hash table or KeyNotFound // along with a boolean indicating if it was found or not. func (s *Uint8MemoTable) Get(val interface{}) (int, bool) { @@ -559,10 +572,13 @@ func (s *Uint8MemoTable) Get(val interface{}) (int, bool) { // value into the table and return the new index. found indicates whether or not it already // existed in the table (true) or was inserted by this call (false). func (s *Uint8MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { + return s.InsertOrGet(val.(uint8)) +} - h := hashInt(uint64(val.(uint8)), 0) +func (s *Uint8MemoTable) InsertOrGet(val uint8) (idx int, found bool, err error) { + h := hashInt(uint64(val), 0) e, ok := s.tbl.Lookup(h, func(v uint8) bool { - return val.(uint8) == v + return val == v }) if ok { @@ -570,7 +586,7 @@ func (s *Uint8MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err found = true } else { idx = s.Size() - s.tbl.Insert(e, h, val.(uint8), int32(idx)) + s.tbl.Insert(e, h, val, int32(idx)) } return } @@ -821,6 +837,11 @@ func (s *Int16MemoTable) WriteOutSubsetLE(start int, out []byte) { s.tbl.WriteOutSubset(start, out) } +func (s *Int16MemoTable) Exists(val int16) bool { + _, ok := s.Get(val) + return ok +} + // Get returns the index of the requested value in the hash table or KeyNotFound // along with a boolean indicating if it was found or not. func (s *Int16MemoTable) Get(val interface{}) (int, bool) { @@ -836,10 +857,13 @@ func (s *Int16MemoTable) Get(val interface{}) (int, bool) { // value into the table and return the new index. found indicates whether or not it already // existed in the table (true) or was inserted by this call (false). func (s *Int16MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { + return s.InsertOrGet(val.(int16)) +} - h := hashInt(uint64(val.(int16)), 0) +func (s *Int16MemoTable) InsertOrGet(val int16) (idx int, found bool, err error) { + h := hashInt(uint64(val), 0) e, ok := s.tbl.Lookup(h, func(v int16) bool { - return val.(int16) == v + return val == v }) if ok { @@ -847,7 +871,7 @@ func (s *Int16MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err found = true } else { idx = s.Size() - s.tbl.Insert(e, h, val.(int16), int32(idx)) + s.tbl.Insert(e, h, val, int32(idx)) } return } @@ -1098,6 +1122,11 @@ func (s *Uint16MemoTable) WriteOutSubsetLE(start int, out []byte) { s.tbl.WriteOutSubset(start, out) } +func (s *Uint16MemoTable) Exists(val uint16) bool { + _, ok := s.Get(val) + return ok +} + // Get returns the index of the requested value in the hash table or KeyNotFound // along with a boolean indicating if it was found or not. func (s *Uint16MemoTable) Get(val interface{}) (int, bool) { @@ -1113,10 +1142,13 @@ func (s *Uint16MemoTable) Get(val interface{}) (int, bool) { // value into the table and return the new index. found indicates whether or not it already // existed in the table (true) or was inserted by this call (false). func (s *Uint16MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { + return s.InsertOrGet(val.(uint16)) +} - h := hashInt(uint64(val.(uint16)), 0) +func (s *Uint16MemoTable) InsertOrGet(val uint16) (idx int, found bool, err error) { + h := hashInt(uint64(val), 0) e, ok := s.tbl.Lookup(h, func(v uint16) bool { - return val.(uint16) == v + return val == v }) if ok { @@ -1124,7 +1156,7 @@ func (s *Uint16MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err found = true } else { idx = s.Size() - s.tbl.Insert(e, h, val.(uint16), int32(idx)) + s.tbl.Insert(e, h, val, int32(idx)) } return } @@ -1375,6 +1407,11 @@ func (s *Int32MemoTable) WriteOutSubsetLE(start int, out []byte) { s.tbl.WriteOutSubset(start, out) } +func (s *Int32MemoTable) Exists(val int32) bool { + _, ok := s.Get(val) + return ok +} + // Get returns the index of the requested value in the hash table or KeyNotFound // along with a boolean indicating if it was found or not. func (s *Int32MemoTable) Get(val interface{}) (int, bool) { @@ -1390,10 +1427,13 @@ func (s *Int32MemoTable) Get(val interface{}) (int, bool) { // value into the table and return the new index. found indicates whether or not it already // existed in the table (true) or was inserted by this call (false). func (s *Int32MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { + return s.InsertOrGet(val.(int32)) +} - h := hashInt(uint64(val.(int32)), 0) +func (s *Int32MemoTable) InsertOrGet(val int32) (idx int, found bool, err error) { + h := hashInt(uint64(val), 0) e, ok := s.tbl.Lookup(h, func(v int32) bool { - return val.(int32) == v + return val == v }) if ok { @@ -1401,7 +1441,7 @@ func (s *Int32MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err found = true } else { idx = s.Size() - s.tbl.Insert(e, h, val.(int32), int32(idx)) + s.tbl.Insert(e, h, val, int32(idx)) } return } @@ -1652,6 +1692,11 @@ func (s *Int64MemoTable) WriteOutSubsetLE(start int, out []byte) { s.tbl.WriteOutSubset(start, out) } +func (s *Int64MemoTable) Exists(val int64) bool { + _, ok := s.Get(val) + return ok +} + // Get returns the index of the requested value in the hash table or KeyNotFound // along with a boolean indicating if it was found or not. func (s *Int64MemoTable) Get(val interface{}) (int, bool) { @@ -1667,10 +1712,13 @@ func (s *Int64MemoTable) Get(val interface{}) (int, bool) { // value into the table and return the new index. found indicates whether or not it already // existed in the table (true) or was inserted by this call (false). func (s *Int64MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { + return s.InsertOrGet(val.(int64)) +} - h := hashInt(uint64(val.(int64)), 0) +func (s *Int64MemoTable) InsertOrGet(val int64) (idx int, found bool, err error) { + h := hashInt(uint64(val), 0) e, ok := s.tbl.Lookup(h, func(v int64) bool { - return val.(int64) == v + return val == v }) if ok { @@ -1678,7 +1726,7 @@ func (s *Int64MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err found = true } else { idx = s.Size() - s.tbl.Insert(e, h, val.(int64), int32(idx)) + s.tbl.Insert(e, h, val, int32(idx)) } return } @@ -1929,6 +1977,11 @@ func (s *Uint32MemoTable) WriteOutSubsetLE(start int, out []byte) { s.tbl.WriteOutSubset(start, out) } +func (s *Uint32MemoTable) Exists(val uint32) bool { + _, ok := s.Get(val) + return ok +} + // Get returns the index of the requested value in the hash table or KeyNotFound // along with a boolean indicating if it was found or not. func (s *Uint32MemoTable) Get(val interface{}) (int, bool) { @@ -1944,10 +1997,13 @@ func (s *Uint32MemoTable) Get(val interface{}) (int, bool) { // value into the table and return the new index. found indicates whether or not it already // existed in the table (true) or was inserted by this call (false). func (s *Uint32MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { + return s.InsertOrGet(val.(uint32)) +} - h := hashInt(uint64(val.(uint32)), 0) +func (s *Uint32MemoTable) InsertOrGet(val uint32) (idx int, found bool, err error) { + h := hashInt(uint64(val), 0) e, ok := s.tbl.Lookup(h, func(v uint32) bool { - return val.(uint32) == v + return val == v }) if ok { @@ -1955,7 +2011,7 @@ func (s *Uint32MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err found = true } else { idx = s.Size() - s.tbl.Insert(e, h, val.(uint32), int32(idx)) + s.tbl.Insert(e, h, val, int32(idx)) } return } @@ -2206,6 +2262,11 @@ func (s *Uint64MemoTable) WriteOutSubsetLE(start int, out []byte) { s.tbl.WriteOutSubset(start, out) } +func (s *Uint64MemoTable) Exists(val uint64) bool { + _, ok := s.Get(val) + return ok +} + // Get returns the index of the requested value in the hash table or KeyNotFound // along with a boolean indicating if it was found or not. func (s *Uint64MemoTable) Get(val interface{}) (int, bool) { @@ -2221,10 +2282,13 @@ func (s *Uint64MemoTable) Get(val interface{}) (int, bool) { // value into the table and return the new index. found indicates whether or not it already // existed in the table (true) or was inserted by this call (false). func (s *Uint64MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { + return s.InsertOrGet(val.(uint64)) +} - h := hashInt(uint64(val.(uint64)), 0) +func (s *Uint64MemoTable) InsertOrGet(val uint64) (idx int, found bool, err error) { + h := hashInt(uint64(val), 0) e, ok := s.tbl.Lookup(h, func(v uint64) bool { - return val.(uint64) == v + return val == v }) if ok { @@ -2232,7 +2296,7 @@ func (s *Uint64MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err found = true } else { idx = s.Size() - s.tbl.Insert(e, h, val.(uint64), int32(idx)) + s.tbl.Insert(e, h, val, int32(idx)) } return } @@ -2483,6 +2547,11 @@ func (s *Float32MemoTable) WriteOutSubsetLE(start int, out []byte) { s.tbl.WriteOutSubset(start, out) } +func (s *Float32MemoTable) Exists(val float32) bool { + _, ok := s.Get(val) + return ok +} + // Get returns the index of the requested value in the hash table or KeyNotFound // along with a boolean indicating if it was found or not. func (s *Float32MemoTable) Get(val interface{}) (int, bool) { @@ -2508,19 +2577,23 @@ func (s *Float32MemoTable) Get(val interface{}) (int, bool) { // value into the table and return the new index. found indicates whether or not it already // existed in the table (true) or was inserted by this call (false). func (s *Float32MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { + return s.InsertOrGet(val.(float32)) +} + +func (s *Float32MemoTable) InsertOrGet(val float32) (idx int, found bool, err error) { var cmp func(float32) bool - if math.IsNaN(float64(val.(float32))) { + if math.IsNaN(float64(val)) { cmp = isNan32Cmp // use consistent internal bit pattern for NaN regardless of the pattern // that is passed to us. NaN is NaN is NaN val = float32(math.NaN()) } else { - cmp = func(v float32) bool { return val.(float32) == v } + cmp = func(v float32) bool { return val == v } } - h := hashFloat32(val.(float32), 0) + h := hashFloat32(val, 0) e, ok := s.tbl.Lookup(h, cmp) if ok { @@ -2528,7 +2601,7 @@ func (s *Float32MemoTable) GetOrInsert(val interface{}) (idx int, found bool, er found = true } else { idx = s.Size() - s.tbl.Insert(e, h, val.(float32), int32(idx)) + s.tbl.Insert(e, h, val, int32(idx)) } return } @@ -2779,6 +2852,11 @@ func (s *Float64MemoTable) WriteOutSubsetLE(start int, out []byte) { s.tbl.WriteOutSubset(start, out) } +func (s *Float64MemoTable) Exists(val float64) bool { + _, ok := s.Get(val) + return ok +} + // Get returns the index of the requested value in the hash table or KeyNotFound // along with a boolean indicating if it was found or not. func (s *Float64MemoTable) Get(val interface{}) (int, bool) { @@ -2803,18 +2881,22 @@ func (s *Float64MemoTable) Get(val interface{}) (int, bool) { // value into the table and return the new index. found indicates whether or not it already // existed in the table (true) or was inserted by this call (false). func (s *Float64MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { + return s.InsertOrGet(val.(float64)) +} + +func (s *Float64MemoTable) InsertOrGet(val float64) (idx int, found bool, err error) { var cmp func(float64) bool - if math.IsNaN(val.(float64)) { + if math.IsNaN(val) { cmp = math.IsNaN // use consistent internal bit pattern for NaN regardless of the pattern // that is passed to us. NaN is NaN is NaN val = math.NaN() } else { - cmp = func(v float64) bool { return val.(float64) == v } + cmp = func(v float64) bool { return val == v } } - h := hashFloat64(val.(float64), 0) + h := hashFloat64(val, 0) e, ok := s.tbl.Lookup(h, cmp) if ok { @@ -2822,7 +2904,7 @@ func (s *Float64MemoTable) GetOrInsert(val interface{}) (idx int, found bool, er found = true } else { idx = s.Size() - s.tbl.Insert(e, h, val.(float64), int32(idx)) + s.tbl.Insert(e, h, val, int32(idx)) } return } diff --git a/internal/hashing/xxh3_memo_table.gen.go.tmpl b/internal/hashing/xxh3_memo_table.gen.go.tmpl index 9ba35c7..14a8f21 100644 --- a/internal/hashing/xxh3_memo_table.gen.go.tmpl +++ b/internal/hashing/xxh3_memo_table.gen.go.tmpl @@ -267,6 +267,11 @@ func (s *{{.Name}}MemoTable) WriteOutSubsetLE(start int, out []byte) { s.tbl.WriteOutSubset(start, out) } +func (s *{{.Name}}MemoTable) Exists(val {{.name}}) bool { + _, ok := s.Get(val) + return ok +} + // Get returns the index of the requested value in the hash table or KeyNotFound // along with a boolean indicating if it was found or not. func (s *{{.Name}}MemoTable) Get(val interface{}) (int, bool) { @@ -304,31 +309,35 @@ func (s *{{.Name}}MemoTable) Get(val interface{}) (int, bool) { // value into the table and return the new index. found indicates whether or not it already // existed in the table (true) or was inserted by this call (false). func (s *{{.Name}}MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { - {{if and (ne .Name "Float32") (ne .Name "Float64") }} - h := hashInt(uint64(val.({{.name}})), 0) + return s.InsertOrGet(val.({{.name}})) +} + +func (s *{{.Name}}MemoTable) InsertOrGet(val {{.name}}) (idx int, found bool, err error) { + {{if and (ne .Name "Float32") (ne .Name "Float64") -}} + h := hashInt(uint64(val), 0) e, ok := s.tbl.Lookup(h, func(v {{.name}}) bool { - return val.({{.name}}) == v + return val == v }) {{ else }} var cmp func({{.name}}) bool {{if eq .Name "Float32"}} - if math.IsNaN(float64(val.(float32))) { + if math.IsNaN(float64(val)) { cmp = isNan32Cmp // use consistent internal bit pattern for NaN regardless of the pattern // that is passed to us. NaN is NaN is NaN val = float32(math.NaN()) {{ else -}} - if math.IsNaN(val.(float64)) { + if math.IsNaN(val) { cmp = math.IsNaN // use consistent internal bit pattern for NaN regardless of the pattern // that is passed to us. NaN is NaN is NaN val = math.NaN() {{end -}} } else { - cmp = func(v {{.name}}) bool { return val.({{.name}}) == v } + cmp = func(v {{.name}}) bool { return val == v } } - h := hash{{.Name}}(val.({{.name}}), 0) + h := hash{{.Name}}(val, 0) e, ok := s.tbl.Lookup(h, cmp) {{ end }} if ok { @@ -336,7 +345,7 @@ func (s *{{.Name}}MemoTable) GetOrInsert(val interface{}) (idx int, found bool, found = true } else { idx = s.Size() - s.tbl.Insert(e, h, val.({{.name}}), int32(idx)) + s.tbl.Insert(e, h, val, int32(idx)) } return } diff --git a/internal/hashing/xxh3_memo_table.go b/internal/hashing/xxh3_memo_table.go index fbb8b33..f10a9b2 100644 --- a/internal/hashing/xxh3_memo_table.go +++ b/internal/hashing/xxh3_memo_table.go @@ -74,6 +74,18 @@ type MemoTable interface { WriteOutSubset(offset int, out []byte) } +type MemoTypes interface { + int8 | int16 | int32 | int64 | + uint8 | uint16 | uint32 | uint64 | + float32 | float64 | []byte +} + +type TypedMemoTable[T MemoTypes] interface { + MemoTable + Exists(T) bool + InsertOrGet(val T) (idx int, found bool, err error) +} + type NumericMemoTable interface { MemoTable WriteOutLE(out []byte) @@ -202,25 +214,17 @@ func (BinaryMemoTable) getHash(val interface{}) uint64 { } } -// helper function to append the given value to the builder regardless -// of the underlying binary type. -func (b *BinaryMemoTable) appendVal(val interface{}) { - switch v := val.(type) { - case string: - b.builder.AppendString(v) - case []byte: - b.builder.Append(v) - case ByteSlice: - b.builder.Append(v.Bytes()) - } -} - func (b *BinaryMemoTable) lookup(h uint64, val []byte) (*entryInt32, bool) { return b.tbl.Lookup(h, func(i int32) bool { return bytes.Equal(val, b.builder.Value(int(i))) }) } +func (b *BinaryMemoTable) Exists(val []byte) bool { + _, ok := b.lookup(b.getHash(val), val) + return ok +} + // Get returns the index of the specified value in the table or KeyNotFound, // and a boolean indicating whether it was found in the table. func (b *BinaryMemoTable) Get(val interface{}) (int, bool) { @@ -246,17 +250,21 @@ func (b *BinaryMemoTable) GetOrInsertBytes(val []byte) (idx int, found bool, err return } +func (b *BinaryMemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { + return b.InsertOrGet(b.valAsByteSlice(val)) +} + // GetOrInsert returns the index of the given value in the table, if not found // it is inserted into the table. The return value 'found' indicates whether the value // was found in the table (true) or inserted (false) along with any possible error. -func (b *BinaryMemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { +func (b *BinaryMemoTable) InsertOrGet(val []byte) (idx int, found bool, err error) { h := b.getHash(val) - p, found := b.lookup(h, b.valAsByteSlice(val)) + p, found := b.lookup(h, val) if found { idx = int(p.payload.val) } else { idx = b.Size() - b.appendVal(val) + b.builder.Append(val) b.tbl.Insert(p, h, int32(idx), -1) } return diff --git a/parquet/file/file_writer.go b/parquet/file/file_writer.go index f616831..3c5608a 100644 --- a/parquet/file/file_writer.go +++ b/parquet/file/file_writer.go @@ -213,11 +213,14 @@ func (fw *Writer) Close() (err error) { }() err = fw.FlushWithFooter() - fw.metadata.Clear() } return nil } +func (fw *Writer) FileMetadata() (*metadata.FileMetaData, error) { + return fw.metadata.Snapshot() +} + // FlushWithFooter closes any open row group writer and writes the file footer, leaving // the writer open for additional row groups. Additional footers written by later // calls to FlushWithFooter or Close will be cumulative, so that only the last footer diff --git a/parquet/pqarrow/file_writer.go b/parquet/pqarrow/file_writer.go index 45cfe49..f3e65ac 100644 --- a/parquet/pqarrow/file_writer.go +++ b/parquet/pqarrow/file_writer.go @@ -338,3 +338,7 @@ func (fw *FileWriter) WriteColumnData(data arrow.Array) error { defer chunked.Release() return fw.WriteColumnChunked(chunked, 0, int64(data.Len())) } + +func (fw *FileWriter) FileMetadata() (*metadata.FileMetaData, error) { + return fw.wr.FileMetadata() +} diff --git a/parquet/pqarrow/schema.go b/parquet/pqarrow/schema.go index 416d59f..0342f28 100644 --- a/parquet/pqarrow/schema.go +++ b/parquet/pqarrow/schema.go @@ -240,7 +240,7 @@ func repFromNullable(isnullable bool) parquet.Repetition { return parquet.Repetitions.Required } -func structToNode(typ *arrow.StructType, name string, nullable bool, props *parquet.WriterProperties, arrprops ArrowWriterProperties) (schema.Node, error) { +func structToNode(typ *arrow.StructType, name string, nullable bool, fieldID int32, props *parquet.WriterProperties, arrprops ArrowWriterProperties) (schema.Node, error) { if typ.NumFields() == 0 { return nil, fmt.Errorf("cannot write struct type '%s' with no children field to parquet. Consider adding a dummy child", name) } @@ -254,7 +254,7 @@ func structToNode(typ *arrow.StructType, name string, nullable bool, props *parq children = append(children, n) } - return schema.NewGroupNode(name, repFromNullable(nullable), children, -1) + return schema.NewGroupNode(name, repFromNullable(nullable), children, fieldID) } func fieldToNode(name string, field arrow.Field, props *parquet.WriterProperties, arrprops ArrowWriterProperties) (schema.Node, error) { @@ -267,7 +267,7 @@ func fieldToNode(name string, field arrow.Field, props *parquet.WriterProperties return nil, xerrors.New("nulltype arrow field must be nullable") } case arrow.STRUCT: - return structToNode(field.Type.(*arrow.StructType), field.Name, field.Nullable, props, arrprops) + return structToNode(field.Type.(*arrow.StructType), field.Name, field.Nullable, fieldIDFromMeta(field.Metadata), props, arrprops) case arrow.FIXED_SIZE_LIST, arrow.LIST: elemField := field.Type.(arrow.ListLikeType).ElemField() @@ -276,7 +276,7 @@ func fieldToNode(name string, field arrow.Field, props *parquet.WriterProperties return nil, err } - return schema.ListOfWithName(name, child, repFromNullable(field.Nullable), -1) + return schema.ListOfWithName(name, child, repFromNullable(field.Nullable), fieldIDFromMeta(field.Metadata)) case arrow.DICTIONARY: // parquet has no dictionary type, dictionary is encoding, not schema level dictType := field.Type.(*arrow.DictionaryType) @@ -302,9 +302,9 @@ func fieldToNode(name string, field arrow.Field, props *parquet.WriterProperties } return schema.NewGroupNode(field.Name, repFromNullable(field.Nullable), schema.FieldList{ keyvalNode, - }, -1) + }, fieldIDFromMeta(field.Metadata)) } - return schema.MapOf(field.Name, keyNode, valueNode, repFromNullable(field.Nullable), -1) + return schema.MapOf(field.Name, keyNode, valueNode, repFromNullable(field.Nullable), fieldIDFromMeta(field.Metadata)) } // Not a GroupNode