This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-go.git


The following commit(s) were added to refs/heads/main by this push:
     new 18d6677  fix(arrow/compute): compare kernels with UUID (#174)
18d6677 is described below

commit 18d6677d4440a1a7f702addff36a01e6336d852d
Author: Matt Topol <[email protected]>
AuthorDate: Sat Oct 26 11:27:21 2024 -0400

    fix(arrow/compute): compare kernels with UUID (#174)
    
    Split from #171
    
    Enable using the comparison kernels (equal, less, less_equal, greater,
    greater_equal) with UUID columns and extension types in general.
    
    Tests are added to check the kernel dispatch and to ensure compute via
    substrait works for UUID type scalars.
---
 arrow/compute/exprs/exec.go            | 22 +++++++++++++++++++++-
 arrow/compute/exprs/exec_test.go       | 23 +++++++++++++++++++++--
 arrow/compute/exprs/extension_types.go |  8 ++++++--
 arrow/compute/scalar_compare.go        |  1 +
 arrow/compute/scalar_compare_test.go   |  2 ++
 arrow/compute/utils.go                 |  8 ++++++++
 arrow/extensions/uuid.go               |  5 +++++
 7 files changed, 64 insertions(+), 5 deletions(-)

diff --git a/arrow/compute/exprs/exec.go b/arrow/compute/exprs/exec.go
index 53585f7..8612998 100644
--- a/arrow/compute/exprs/exec.go
+++ b/arrow/compute/exprs/exec.go
@@ -571,6 +571,25 @@ func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.E
                        return nil, err
                }
 
+               var newArgs []compute.Datum
+               // cast arguments if necessary
+               for i, arg := range args {
+                       if !arrow.TypeEqual(argTypes[i], 
arg.(compute.ArrayLikeDatum).Type()) {
+                               if newArgs == nil {
+                                       newArgs = make([]compute.Datum, 
len(args))
+                                       copy(newArgs, args)
+                               }
+                               newArgs[i], err = compute.CastDatum(ctx, arg, 
compute.SafeCastOptions(argTypes[i]))
+                               if err != nil {
+                                       return nil, err
+                               }
+                               defer newArgs[i].Release()
+                       }
+               }
+               if newArgs != nil {
+                       args = newArgs
+               }
+
                kctx := &exec.KernelCtx{Ctx: ctx, Kernel: k}
                init := k.GetInitFn()
                kinitArgs := exec.KernelInitArgs{Kernel: k, Inputs: argTypes, 
Options: opts}
@@ -611,9 +630,10 @@ func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.E
 
                if ctx.Err() == context.Canceled && result != nil {
                        result.Release()
+                       result = nil
                }
 
-               return result, nil
+               return result, err
        }
 
        return nil, arrow.ErrNotImplemented
diff --git a/arrow/compute/exprs/exec_test.go b/arrow/compute/exprs/exec_test.go
index c2a1c27..c02ba63 100644
--- a/arrow/compute/exprs/exec_test.go
+++ b/arrow/compute/exprs/exec_test.go
@@ -27,8 +27,10 @@ import (
        "github.com/apache/arrow-go/v18/arrow/array"
        "github.com/apache/arrow-go/v18/arrow/compute"
        "github.com/apache/arrow-go/v18/arrow/compute/exprs"
+       "github.com/apache/arrow-go/v18/arrow/extensions"
        "github.com/apache/arrow-go/v18/arrow/memory"
        "github.com/apache/arrow-go/v18/arrow/scalar"
+       "github.com/google/uuid"
        "github.com/stretchr/testify/assert"
        "github.com/stretchr/testify/require"
        "github.com/substrait-io/substrait-go/expr"
@@ -135,8 +137,16 @@ func TestComparisons(t *testing.T) {
                one  = scalar.MakeScalar(int32(1))
                two  = scalar.MakeScalar(int32(2))
 
-               str = scalar.MakeScalar("hello")
-               bin = scalar.MakeScalar([]byte("hello"))
+               str            = scalar.MakeScalar("hello")
+               bin            = scalar.MakeScalar([]byte("hello"))
+               exampleUUID    = 
uuid.MustParse("102cb62f-e6f8-4eb0-9973-d9b012ff0967")
+               exampleUUID2   = 
uuid.MustParse("c1b0d8e0-0b0e-4b1e-9b0a-0e0b0d0c0a0b")
+               uuidStorage, _ = scalar.MakeScalarParam(exampleUUID[:],
+                       &arrow.FixedSizeBinaryType{ByteWidth: 16})
+               uuidScalar      = scalar.NewExtensionScalar(uuidStorage, 
extensions.NewUUIDType())
+               uuidStorage2, _ = scalar.MakeScalarParam(exampleUUID2[:],
+                       &arrow.FixedSizeBinaryType{ByteWidth: 16})
+               uuidScalar2 = scalar.NewExtensionScalar(uuidStorage2, 
extensions.NewUUIDType())
        )
 
        getArgType := func(dt arrow.DataType) types.Type {
@@ -147,6 +157,8 @@ func TestComparisons(t *testing.T) {
                        return &types.StringType{}
                case arrow.BINARY:
                        return &types.BinaryType{}
+               case arrow.EXTENSION:
+                       return &types.UUIDType{}
                }
                panic("wtf")
        }
@@ -190,6 +202,13 @@ func TestComparisons(t *testing.T) {
 
        expect(t, "equal", str, bin, true)
        expect(t, "equal", bin, str, true)
+
+       expect(t, "equal", uuidScalar, uuidScalar, true)
+       expect(t, "equal", uuidScalar, uuidScalar2, false)
+       expect(t, "less", uuidScalar, uuidScalar2, true)
+       expect(t, "less", uuidScalar2, uuidScalar, false)
+       expect(t, "greater", uuidScalar, uuidScalar2, false)
+       expect(t, "greater", uuidScalar2, uuidScalar, true)
 }
 
 func TestExecuteFieldRef(t *testing.T) {
diff --git a/arrow/compute/exprs/extension_types.go b/arrow/compute/exprs/extension_types.go
index 448c5a4..db780cb 100644
--- a/arrow/compute/exprs/extension_types.go
+++ b/arrow/compute/exprs/extension_types.go
@@ -75,7 +75,7 @@ func (ef *simpleExtensionTypeFactory[P]) ExtensionEquals(other arrow.ExtensionTy
        return ef.params == rhs.params
 }
 func (ef *simpleExtensionTypeFactory[P]) ArrayType() reflect.Type {
-       return reflect.TypeOf(array.ExtensionArrayBase{})
+       return reflect.TypeOf(simpleExtensionArrayFactory[P]{})
 }
 
 func (ef *simpleExtensionTypeFactory[P]) CreateType(params P) arrow.DataType {
@@ -91,10 +91,14 @@ func (ef *simpleExtensionTypeFactory[P]) CreateType(params P) arrow.DataType {
        }
 }
 
+type simpleExtensionArrayFactory[P comparable] struct {
+       array.ExtensionArrayBase
+}
+
 type uuidExtParams struct{}
 
 var uuidType = simpleExtensionTypeFactory[uuidExtParams]{
-       name: "uuid", getStorage: func(uuidExtParams) arrow.DataType {
+       name: "arrow.uuid", getStorage: func(uuidExtParams) arrow.DataType {
                return &arrow.FixedSizeBinaryType{ByteWidth: 16}
        }}
 
diff --git a/arrow/compute/scalar_compare.go b/arrow/compute/scalar_compare.go
index 0b182eb..cfead2a 100644
--- a/arrow/compute/scalar_compare.go
+++ b/arrow/compute/scalar_compare.go
@@ -52,6 +52,7 @@ func (fn *compareFunction) DispatchBest(vals ...arrow.DataType) (exec.Kernel, er
        }
 
        ensureDictionaryDecoded(vals...)
+       ensureNoExtensionType(vals...)
        replaceNullWithOtherType(vals...)
 
        if dt := commonNumeric(vals...); dt != nil {
diff --git a/arrow/compute/scalar_compare_test.go b/arrow/compute/scalar_compare_test.go
index ba7e110..b0c9ab9 100644
--- a/arrow/compute/scalar_compare_test.go
+++ b/arrow/compute/scalar_compare_test.go
@@ -30,6 +30,7 @@ import (
        "github.com/apache/arrow-go/v18/arrow/compute"
        "github.com/apache/arrow-go/v18/arrow/compute/exec"
        "github.com/apache/arrow-go/v18/arrow/compute/internal/kernels"
+       "github.com/apache/arrow-go/v18/arrow/extensions"
        "github.com/apache/arrow-go/v18/arrow/internal/testing/gen"
        "github.com/apache/arrow-go/v18/arrow/memory"
        "github.com/apache/arrow-go/v18/arrow/scalar"
@@ -1289,6 +1290,7 @@ func TestCompareKernelsDispatchBest(t *testing.T) {
                        &arrow.Decimal128Type{Precision: 3, Scale: 2}, 
&arrow.Decimal128Type{Precision: 21, Scale: 2}},
                {arrow.PrimitiveTypes.Int64, &arrow.Decimal128Type{Precision: 
3, Scale: 2},
                        &arrow.Decimal128Type{Precision: 21, Scale: 2}, 
&arrow.Decimal128Type{Precision: 3, Scale: 2}},
+               {extensions.NewUUIDType(), extensions.NewUUIDType(), 
&arrow.FixedSizeBinaryType{ByteWidth: 16}, 
&arrow.FixedSizeBinaryType{ByteWidth: 16}},
        }
 
        for _, name := range []string{"equal", "not_equal", "less", 
"less_equal", "greater", "greater_equal"} {
diff --git a/arrow/compute/utils.go b/arrow/compute/utils.go
index a6e311d..7e4df8d 100644
--- a/arrow/compute/utils.go
+++ b/arrow/compute/utils.go
@@ -105,6 +105,14 @@ func ensureDictionaryDecoded(vals ...arrow.DataType) {
        }
 }
 
+func ensureNoExtensionType(vals ...arrow.DataType) {
+       for i, v := range vals {
+               if v.ID() == arrow.EXTENSION {
+                       vals[i] = v.(arrow.ExtensionType).StorageType()
+               }
+       }
+}
+
 func replaceNullWithOtherType(vals ...arrow.DataType) {
        debug.Assert(len(vals) == 2, "should be length 2")
 
diff --git a/arrow/extensions/uuid.go b/arrow/extensions/uuid.go
index 0c2f175..9aac022 100644
--- a/arrow/extensions/uuid.go
+++ b/arrow/extensions/uuid.go
@@ -228,6 +228,9 @@ func (*UUIDType) ExtensionName() string {
        return "arrow.uuid"
 }
 
+func (*UUIDType) Bytes() int    { return 16 }
+func (*UUIDType) BitWidth() int { return 128 }
+
 func (e *UUIDType) String() string {
        return fmt.Sprintf("extension<%s>", e.ExtensionName())
 }
@@ -262,4 +265,6 @@ var (
        _ array.CustomExtensionBuilder = (*UUIDType)(nil)
        _ array.ExtensionArray         = (*UUIDArray)(nil)
        _ array.Builder                = (*UUIDBuilder)(nil)
+
+       _ arrow.FixedWidthDataType = (*UUIDType)(nil)
 )

Reply via email to