This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-go.git
The following commit(s) were added to refs/heads/main by this push:
new 18d6677 fix(arrow/compute): compare kernels with UUID (#174)
18d6677 is described below
commit 18d6677d4440a1a7f702addff36a01e6336d852d
Author: Matt Topol <[email protected]>
AuthorDate: Sat Oct 26 11:27:21 2024 -0400
fix(arrow/compute): compare kernels with UUID (#174)
Split from #171
Enable using the comparison kernels (equal, not_equal, less, less_equal,
greater, greater_equal) with UUID columns and, more generally, with any
extension type by dispatching on the extension's storage type.
Tests are added to check the kernel dispatch and to ensure that compute via
Substrait works for UUID-typed scalars.
---
arrow/compute/exprs/exec.go | 22 +++++++++++++++++++++-
arrow/compute/exprs/exec_test.go | 23 +++++++++++++++++++++--
arrow/compute/exprs/extension_types.go | 8 ++++++--
arrow/compute/scalar_compare.go | 1 +
arrow/compute/scalar_compare_test.go | 2 ++
arrow/compute/utils.go | 8 ++++++++
arrow/extensions/uuid.go | 5 +++++
7 files changed, 64 insertions(+), 5 deletions(-)
diff --git a/arrow/compute/exprs/exec.go b/arrow/compute/exprs/exec.go
index 53585f7..8612998 100644
--- a/arrow/compute/exprs/exec.go
+++ b/arrow/compute/exprs/exec.go
@@ -571,6 +571,25 @@ func executeScalarBatch(ctx context.Context, input
compute.ExecBatch, exp expr.E
return nil, err
}
+ var newArgs []compute.Datum
+ // cast arguments if necessary
+ for i, arg := range args {
+ if !arrow.TypeEqual(argTypes[i],
arg.(compute.ArrayLikeDatum).Type()) {
+ if newArgs == nil {
+ newArgs = make([]compute.Datum,
len(args))
+ copy(newArgs, args)
+ }
+ newArgs[i], err = compute.CastDatum(ctx, arg,
compute.SafeCastOptions(argTypes[i]))
+ if err != nil {
+ return nil, err
+ }
+ defer newArgs[i].Release()
+ }
+ }
+ if newArgs != nil {
+ args = newArgs
+ }
+
kctx := &exec.KernelCtx{Ctx: ctx, Kernel: k}
init := k.GetInitFn()
kinitArgs := exec.KernelInitArgs{Kernel: k, Inputs: argTypes,
Options: opts}
@@ -611,9 +630,10 @@ func executeScalarBatch(ctx context.Context, input
compute.ExecBatch, exp expr.E
if ctx.Err() == context.Canceled && result != nil {
result.Release()
+ result = nil
}
- return result, nil
+ return result, err
}
return nil, arrow.ErrNotImplemented
diff --git a/arrow/compute/exprs/exec_test.go b/arrow/compute/exprs/exec_test.go
index c2a1c27..c02ba63 100644
--- a/arrow/compute/exprs/exec_test.go
+++ b/arrow/compute/exprs/exec_test.go
@@ -27,8 +27,10 @@ import (
"github.com/apache/arrow-go/v18/arrow/array"
"github.com/apache/arrow-go/v18/arrow/compute"
"github.com/apache/arrow-go/v18/arrow/compute/exprs"
+ "github.com/apache/arrow-go/v18/arrow/extensions"
"github.com/apache/arrow-go/v18/arrow/memory"
"github.com/apache/arrow-go/v18/arrow/scalar"
+ "github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/substrait-io/substrait-go/expr"
@@ -135,8 +137,16 @@ func TestComparisons(t *testing.T) {
one = scalar.MakeScalar(int32(1))
two = scalar.MakeScalar(int32(2))
- str = scalar.MakeScalar("hello")
- bin = scalar.MakeScalar([]byte("hello"))
+ str = scalar.MakeScalar("hello")
+ bin = scalar.MakeScalar([]byte("hello"))
+ exampleUUID =
uuid.MustParse("102cb62f-e6f8-4eb0-9973-d9b012ff0967")
+ exampleUUID2 =
uuid.MustParse("c1b0d8e0-0b0e-4b1e-9b0a-0e0b0d0c0a0b")
+ uuidStorage, _ = scalar.MakeScalarParam(exampleUUID[:],
+ &arrow.FixedSizeBinaryType{ByteWidth: 16})
+ uuidScalar = scalar.NewExtensionScalar(uuidStorage,
extensions.NewUUIDType())
+ uuidStorage2, _ = scalar.MakeScalarParam(exampleUUID2[:],
+ &arrow.FixedSizeBinaryType{ByteWidth: 16})
+ uuidScalar2 = scalar.NewExtensionScalar(uuidStorage2,
extensions.NewUUIDType())
)
getArgType := func(dt arrow.DataType) types.Type {
@@ -147,6 +157,8 @@ func TestComparisons(t *testing.T) {
return &types.StringType{}
case arrow.BINARY:
return &types.BinaryType{}
+ case arrow.EXTENSION:
+ return &types.UUIDType{}
}
panic("wtf")
}
@@ -190,6 +202,13 @@ func TestComparisons(t *testing.T) {
expect(t, "equal", str, bin, true)
expect(t, "equal", bin, str, true)
+
+ expect(t, "equal", uuidScalar, uuidScalar, true)
+ expect(t, "equal", uuidScalar, uuidScalar2, false)
+ expect(t, "less", uuidScalar, uuidScalar2, true)
+ expect(t, "less", uuidScalar2, uuidScalar, false)
+ expect(t, "greater", uuidScalar, uuidScalar2, false)
+ expect(t, "greater", uuidScalar2, uuidScalar, true)
}
func TestExecuteFieldRef(t *testing.T) {
diff --git a/arrow/compute/exprs/extension_types.go
b/arrow/compute/exprs/extension_types.go
index 448c5a4..db780cb 100644
--- a/arrow/compute/exprs/extension_types.go
+++ b/arrow/compute/exprs/extension_types.go
@@ -75,7 +75,7 @@ func (ef *simpleExtensionTypeFactory[P])
ExtensionEquals(other arrow.ExtensionTy
return ef.params == rhs.params
}
func (ef *simpleExtensionTypeFactory[P]) ArrayType() reflect.Type {
- return reflect.TypeOf(array.ExtensionArrayBase{})
+ return reflect.TypeOf(simpleExtensionArrayFactory[P]{})
}
func (ef *simpleExtensionTypeFactory[P]) CreateType(params P) arrow.DataType {
@@ -91,10 +91,14 @@ func (ef *simpleExtensionTypeFactory[P]) CreateType(params
P) arrow.DataType {
}
}
+type simpleExtensionArrayFactory[P comparable] struct {
+ array.ExtensionArrayBase
+}
+
type uuidExtParams struct{}
var uuidType = simpleExtensionTypeFactory[uuidExtParams]{
- name: "uuid", getStorage: func(uuidExtParams) arrow.DataType {
+ name: "arrow.uuid", getStorage: func(uuidExtParams) arrow.DataType {
return &arrow.FixedSizeBinaryType{ByteWidth: 16}
}}
diff --git a/arrow/compute/scalar_compare.go b/arrow/compute/scalar_compare.go
index 0b182eb..cfead2a 100644
--- a/arrow/compute/scalar_compare.go
+++ b/arrow/compute/scalar_compare.go
@@ -52,6 +52,7 @@ func (fn *compareFunction) DispatchBest(vals
...arrow.DataType) (exec.Kernel, er
}
ensureDictionaryDecoded(vals...)
+ ensureNoExtensionType(vals...)
replaceNullWithOtherType(vals...)
if dt := commonNumeric(vals...); dt != nil {
diff --git a/arrow/compute/scalar_compare_test.go
b/arrow/compute/scalar_compare_test.go
index ba7e110..b0c9ab9 100644
--- a/arrow/compute/scalar_compare_test.go
+++ b/arrow/compute/scalar_compare_test.go
@@ -30,6 +30,7 @@ import (
"github.com/apache/arrow-go/v18/arrow/compute"
"github.com/apache/arrow-go/v18/arrow/compute/exec"
"github.com/apache/arrow-go/v18/arrow/compute/internal/kernels"
+ "github.com/apache/arrow-go/v18/arrow/extensions"
"github.com/apache/arrow-go/v18/arrow/internal/testing/gen"
"github.com/apache/arrow-go/v18/arrow/memory"
"github.com/apache/arrow-go/v18/arrow/scalar"
@@ -1289,6 +1290,7 @@ func TestCompareKernelsDispatchBest(t *testing.T) {
&arrow.Decimal128Type{Precision: 3, Scale: 2},
&arrow.Decimal128Type{Precision: 21, Scale: 2}},
{arrow.PrimitiveTypes.Int64, &arrow.Decimal128Type{Precision:
3, Scale: 2},
&arrow.Decimal128Type{Precision: 21, Scale: 2},
&arrow.Decimal128Type{Precision: 3, Scale: 2}},
+ {extensions.NewUUIDType(), extensions.NewUUIDType(),
&arrow.FixedSizeBinaryType{ByteWidth: 16},
&arrow.FixedSizeBinaryType{ByteWidth: 16}},
}
for _, name := range []string{"equal", "not_equal", "less",
"less_equal", "greater", "greater_equal"} {
diff --git a/arrow/compute/utils.go b/arrow/compute/utils.go
index a6e311d..7e4df8d 100644
--- a/arrow/compute/utils.go
+++ b/arrow/compute/utils.go
@@ -105,6 +105,14 @@ func ensureDictionaryDecoded(vals ...arrow.DataType) {
}
}
+func ensureNoExtensionType(vals ...arrow.DataType) {
+ for i, v := range vals {
+ if v.ID() == arrow.EXTENSION {
+ vals[i] = v.(arrow.ExtensionType).StorageType()
+ }
+ }
+}
+
func replaceNullWithOtherType(vals ...arrow.DataType) {
debug.Assert(len(vals) == 2, "should be length 2")
diff --git a/arrow/extensions/uuid.go b/arrow/extensions/uuid.go
index 0c2f175..9aac022 100644
--- a/arrow/extensions/uuid.go
+++ b/arrow/extensions/uuid.go
@@ -228,6 +228,9 @@ func (*UUIDType) ExtensionName() string {
return "arrow.uuid"
}
+func (*UUIDType) Bytes() int { return 16 }
+func (*UUIDType) BitWidth() int { return 128 }
+
func (e *UUIDType) String() string {
return fmt.Sprintf("extension<%s>", e.ExtensionName())
}
@@ -262,4 +265,6 @@ var (
_ array.CustomExtensionBuilder = (*UUIDType)(nil)
_ array.ExtensionArray = (*UUIDArray)(nil)
_ array.Builder = (*UUIDBuilder)(nil)
+
+ _ arrow.FixedWidthDataType = (*UUIDType)(nil)
)