This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-go.git
The following commit(s) were added to refs/heads/main by this push:
new f0c5d99 feat(arrow/compute): make is_nan dispatchable (#177)
f0c5d99 is described below
commit f0c5d9939d3f0835ea0531639b1fd85b362787bf
Author: Matt Topol <[email protected]>
AuthorDate: Tue Oct 29 11:38:21 2024 -0400
feat(arrow/compute): make is_nan dispatchable (#177)
Currently the `is_nan` compute kernel is a `MetaFunction`, which cannot
be dispatched via kernel dispatch, so it is usable only by calling it
directly with `CallFunction`. Making it a proper scalar function instead
of a `MetaFunction` improves its compatibility and makes it dispatchable,
so it can be called through the Substrait interface in the `exprs`
package. Note that, as the updated tests show, a null input now yields
`false` rather than null, because the new kernels do not emit a validity
bitmap.
---
.../compute/internal/kernels/scalar_comparisons.go | 52 ++++++++++++++++++++++
arrow/compute/scalar_compare.go | 39 +++-------------
arrow/compute/scalar_compare_test.go | 4 +-
3 files changed, 61 insertions(+), 34 deletions(-)
diff --git a/arrow/compute/internal/kernels/scalar_comparisons.go
b/arrow/compute/internal/kernels/scalar_comparisons.go
index b30605b..e4a5054 100644
--- a/arrow/compute/internal/kernels/scalar_comparisons.go
+++ b/arrow/compute/internal/kernels/scalar_comparisons.go
@@ -745,3 +745,55 @@ func IsNullNotNullKernels() []exec.ScalarKernel {
return results
}
+
+func ConstBoolExec(val bool) func(*exec.KernelCtx, *exec.ExecSpan,
*exec.ExecResult) error {
+ return func(ctx *exec.KernelCtx, batch *exec.ExecSpan, out
*exec.ExecResult) error {
+ bitutil.SetBitsTo(out.Buffers[1].Buf, out.Offset, batch.Len,
val)
+ return nil
+ }
+}
+
+func isNanKernelExec[T float32 | float64](ctx *exec.KernelCtx, batch
*exec.ExecSpan, out *exec.ExecResult) error {
+ kn := ctx.Kernel.(*exec.ScalarKernel)
+ knData := kn.Data.(CompareFuncData).Funcs()
+
+ outPrefix := int(out.Offset % 8)
+ outBuf := out.Buffers[1].Buf[out.Offset/8:]
+
+ inputBytes := getOffsetSpanBytes(&batch.Values[0].Array)
+ knData.funcAA(inputBytes, inputBytes, outBuf, outPrefix)
+ return nil
+}
+
+func IsNaNKernels() []exec.ScalarKernel {
+ outputType := exec.NewOutputType(arrow.FixedWidthTypes.Boolean)
+
+ knFloat32 :=
exec.NewScalarKernel([]exec.InputType{exec.NewExactInput(arrow.PrimitiveTypes.Float32)},
+ outputType, isNanKernelExec[float32], nil)
+ knFloat32.Data = genCompareKernel[float32](CmpNE)
+ knFloat32.NullHandling = exec.NullNoOutput
+ knFloat64 :=
exec.NewScalarKernel([]exec.InputType{exec.NewExactInput(arrow.PrimitiveTypes.Float64)},
+ outputType, isNanKernelExec[float64], nil)
+ knFloat64.Data = genCompareKernel[float64](CmpNE)
+ knFloat64.NullHandling = exec.NullNoOutput
+
+ kernels := []exec.ScalarKernel{knFloat32, knFloat64}
+
+ for _, dt := range intTypes {
+ kn := exec.NewScalarKernel(
+ []exec.InputType{exec.NewExactInput(dt)},
+ outputType, ConstBoolExec(false), nil)
+ kn.NullHandling = exec.NullNoOutput
+ kernels = append(kernels, kn)
+ }
+
+ for _, id := range []arrow.Type{arrow.NULL, arrow.DURATION,
arrow.DECIMAL32, arrow.DECIMAL64, arrow.DECIMAL128, arrow.DECIMAL256} {
+ kn := exec.NewScalarKernel(
+ []exec.InputType{exec.NewIDInput(id)},
+ outputType, ConstBoolExec(false), nil)
+ kn.NullHandling = exec.NullNoOutput
+ kernels = append(kernels, kn)
+ }
+
+ return kernels
+}
diff --git a/arrow/compute/scalar_compare.go b/arrow/compute/scalar_compare.go
index cfead2a..0e853a6 100644
--- a/arrow/compute/scalar_compare.go
+++ b/arrow/compute/scalar_compare.go
@@ -20,12 +20,10 @@ package compute
import (
"context"
- "fmt"
"github.com/apache/arrow-go/v18/arrow"
"github.com/apache/arrow-go/v18/arrow/compute/exec"
"github.com/apache/arrow-go/v18/arrow/compute/internal/kernels"
- "github.com/apache/arrow-go/v18/arrow/scalar"
)
type compareFunction struct {
@@ -152,34 +150,11 @@ func RegisterScalarComparisons(reg FunctionRegistry) {
reg.AddFunction(isNullFn, false)
reg.AddFunction(isNotNullFn, false)
- reg.AddFunction(NewMetaFunction("is_nan", Unary(), EmptyFuncDoc,
- func(ctx context.Context, opts FunctionOptions, args ...Datum)
(Datum, error) {
- type hasType interface {
- Type() arrow.DataType
- }
-
- // only Scalar, Array and ChunkedArray have a Type
method
- arg, ok := args[0].(hasType)
- if !ok {
- // don't support Table/Record/None kinds
- return nil, fmt.Errorf("%w: unsupported type
for is_nan %s",
- arrow.ErrNotImplemented, args[0])
- }
-
- switch arg.Type() {
- case arrow.PrimitiveTypes.Float32,
arrow.PrimitiveTypes.Float64:
- return CallFunction(ctx, "not_equal", nil,
args[0], args[0])
- default:
- if arg, ok := args[0].(ArrayLikeDatum); ok {
- result, err :=
scalar.MakeArrayFromScalar(scalar.NewBooleanScalar(false),
- int(arg.Len()),
GetAllocator(ctx))
- if err != nil {
- return nil, err
- }
- return NewDatumWithoutOwning(result),
nil
- }
-
- return NewDatum(false), nil
- }
- }), false)
+ isNaNFn := &compareFunction{*NewScalarFunction("is_nan", Unary(),
EmptyFuncDoc)}
+ for _, k := range kernels.IsNaNKernels() {
+ if err := isNaNFn.AddKernel(k); err != nil {
+ panic(err)
+ }
+ }
+ reg.AddFunction(isNaNFn, false)
}
diff --git a/arrow/compute/scalar_compare_test.go
b/arrow/compute/scalar_compare_test.go
index b0c9ab9..e45b3af 100644
--- a/arrow/compute/scalar_compare_test.go
+++ b/arrow/compute/scalar_compare_test.go
@@ -1497,8 +1497,8 @@ func (sv *ScalarValiditySuite) TestIsNaN() {
}{
{`[]`, `[]`},
{`[1]`, `[false]`},
- {`[null]`, `[null]`},
- {`["NaN", 1, 0, null]`, `[true, false, false, null]`},
+ {`[null]`, `[false]`},
+ {`["NaN", 1, 0, null]`, `[true, false, false, false]`},
}
for _, typ := range floatingTypes {