This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-go.git


The following commit(s) were added to refs/heads/main by this push:
     new f0c5d99  feat(arrow/compute): make is_nan dispatchable (#177)
f0c5d99 is described below

commit f0c5d9939d3f0835ea0531639b1fd85b362787bf
Author: Matt Topol <[email protected]>
AuthorDate: Tue Oct 29 11:38:21 2024 -0400

    feat(arrow/compute): make is_nan dispatchable (#177)
    
    Currently the `is_nan` compute kernel is a `MetaFunction`, which cannot
    be dispatched via kernel dispatch, so it is only usable by calling it
    directly with `CallFunction`. Shifting it to be a proper function
    instead of a `MetaFunction` improves its compatibility and allows it to
    be dispatched, and thus called through the Substrait interface in the
    `exprs` package.
---
 .../compute/internal/kernels/scalar_comparisons.go | 52 ++++++++++++++++++++++
 arrow/compute/scalar_compare.go                    | 39 +++-------------
 arrow/compute/scalar_compare_test.go               |  4 +-
 3 files changed, 61 insertions(+), 34 deletions(-)

diff --git a/arrow/compute/internal/kernels/scalar_comparisons.go b/arrow/compute/internal/kernels/scalar_comparisons.go
index b30605b..e4a5054 100644
--- a/arrow/compute/internal/kernels/scalar_comparisons.go
+++ b/arrow/compute/internal/kernels/scalar_comparisons.go
@@ -745,3 +745,55 @@ func IsNullNotNullKernels() []exec.ScalarKernel {
 
        return results
 }
+
+func ConstBoolExec(val bool) func(*exec.KernelCtx, *exec.ExecSpan, 
*exec.ExecResult) error {
+       return func(ctx *exec.KernelCtx, batch *exec.ExecSpan, out 
*exec.ExecResult) error {
+               bitutil.SetBitsTo(out.Buffers[1].Buf, out.Offset, batch.Len, 
val)
+               return nil
+       }
+}
+
+func isNanKernelExec[T float32 | float64](ctx *exec.KernelCtx, batch 
*exec.ExecSpan, out *exec.ExecResult) error {
+       kn := ctx.Kernel.(*exec.ScalarKernel)
+       knData := kn.Data.(CompareFuncData).Funcs()
+
+       outPrefix := int(out.Offset % 8)
+       outBuf := out.Buffers[1].Buf[out.Offset/8:]
+
+       inputBytes := getOffsetSpanBytes(&batch.Values[0].Array)
+       knData.funcAA(inputBytes, inputBytes, outBuf, outPrefix)
+       return nil
+}
+
+func IsNaNKernels() []exec.ScalarKernel {
+       outputType := exec.NewOutputType(arrow.FixedWidthTypes.Boolean)
+
+       knFloat32 := 
exec.NewScalarKernel([]exec.InputType{exec.NewExactInput(arrow.PrimitiveTypes.Float32)},
+               outputType, isNanKernelExec[float32], nil)
+       knFloat32.Data = genCompareKernel[float32](CmpNE)
+       knFloat32.NullHandling = exec.NullNoOutput
+       knFloat64 := 
exec.NewScalarKernel([]exec.InputType{exec.NewExactInput(arrow.PrimitiveTypes.Float64)},
+               outputType, isNanKernelExec[float64], nil)
+       knFloat64.Data = genCompareKernel[float64](CmpNE)
+       knFloat64.NullHandling = exec.NullNoOutput
+
+       kernels := []exec.ScalarKernel{knFloat32, knFloat64}
+
+       for _, dt := range intTypes {
+               kn := exec.NewScalarKernel(
+                       []exec.InputType{exec.NewExactInput(dt)},
+                       outputType, ConstBoolExec(false), nil)
+               kn.NullHandling = exec.NullNoOutput
+               kernels = append(kernels, kn)
+       }
+
+       for _, id := range []arrow.Type{arrow.NULL, arrow.DURATION, 
arrow.DECIMAL32, arrow.DECIMAL64, arrow.DECIMAL128, arrow.DECIMAL256} {
+               kn := exec.NewScalarKernel(
+                       []exec.InputType{exec.NewIDInput(id)},
+                       outputType, ConstBoolExec(false), nil)
+               kn.NullHandling = exec.NullNoOutput
+               kernels = append(kernels, kn)
+       }
+
+       return kernels
+}
diff --git a/arrow/compute/scalar_compare.go b/arrow/compute/scalar_compare.go
index cfead2a..0e853a6 100644
--- a/arrow/compute/scalar_compare.go
+++ b/arrow/compute/scalar_compare.go
@@ -20,12 +20,10 @@ package compute
 
 import (
        "context"
-       "fmt"
 
        "github.com/apache/arrow-go/v18/arrow"
        "github.com/apache/arrow-go/v18/arrow/compute/exec"
        "github.com/apache/arrow-go/v18/arrow/compute/internal/kernels"
-       "github.com/apache/arrow-go/v18/arrow/scalar"
 )
 
 type compareFunction struct {
@@ -152,34 +150,11 @@ func RegisterScalarComparisons(reg FunctionRegistry) {
        reg.AddFunction(isNullFn, false)
        reg.AddFunction(isNotNullFn, false)
 
-       reg.AddFunction(NewMetaFunction("is_nan", Unary(), EmptyFuncDoc,
-               func(ctx context.Context, opts FunctionOptions, args ...Datum) 
(Datum, error) {
-                       type hasType interface {
-                               Type() arrow.DataType
-                       }
-
-                       // only Scalar, Array and ChunkedArray have a Type 
method
-                       arg, ok := args[0].(hasType)
-                       if !ok {
-                               // don't support Table/Record/None kinds
-                               return nil, fmt.Errorf("%w: unsupported type 
for is_nan %s",
-                                       arrow.ErrNotImplemented, args[0])
-                       }
-
-                       switch arg.Type() {
-                       case arrow.PrimitiveTypes.Float32, 
arrow.PrimitiveTypes.Float64:
-                               return CallFunction(ctx, "not_equal", nil, 
args[0], args[0])
-                       default:
-                               if arg, ok := args[0].(ArrayLikeDatum); ok {
-                                       result, err := 
scalar.MakeArrayFromScalar(scalar.NewBooleanScalar(false),
-                                               int(arg.Len()), 
GetAllocator(ctx))
-                                       if err != nil {
-                                               return nil, err
-                                       }
-                                       return NewDatumWithoutOwning(result), 
nil
-                               }
-
-                               return NewDatum(false), nil
-                       }
-               }), false)
+       isNaNFn := &compareFunction{*NewScalarFunction("is_nan", Unary(), 
EmptyFuncDoc)}
+       for _, k := range kernels.IsNaNKernels() {
+               if err := isNaNFn.AddKernel(k); err != nil {
+                       panic(err)
+               }
+       }
+       reg.AddFunction(isNaNFn, false)
 }
diff --git a/arrow/compute/scalar_compare_test.go b/arrow/compute/scalar_compare_test.go
index b0c9ab9..e45b3af 100644
--- a/arrow/compute/scalar_compare_test.go
+++ b/arrow/compute/scalar_compare_test.go
@@ -1497,8 +1497,8 @@ func (sv *ScalarValiditySuite) TestIsNaN() {
        }{
                {`[]`, `[]`},
                {`[1]`, `[false]`},
-               {`[null]`, `[null]`},
-               {`["NaN", 1, 0, null]`, `[true, false, false, null]`},
+               {`[null]`, `[false]`},
+               {`["NaN", 1, 0, null]`, `[true, false, false, false]`},
        }
 
        for _, typ := range floatingTypes {

Reply via email to