This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new c72ad1689b ARROW-17587: [Go] Cast From Extension Types (#14016)
c72ad1689b is described below
commit c72ad1689b57b7f48d40d9a353f3337218f5a219
Author: Matt Topol <[email protected]>
AuthorDate: Thu Sep 1 11:14:07 2022 -0400
ARROW-17587: [Go] Cast From Extension Types (#14016)
Authored-by: Matt Topol <[email protected]>
Signed-off-by: Matt Topol <[email protected]>
---
go/arrow/compute/cast.go | 21 ++++++++++++
go/arrow/compute/cast_test.go | 50 ++++++++++++++++++++++++++++
go/arrow/compute/datum.go | 4 +++
go/arrow/compute/internal/exec/span.go | 60 ++++++++++++++++++++++++++++++++++
go/arrow/scalar/scalar.go | 16 +++++----
5 files changed, 144 insertions(+), 7 deletions(-)
diff --git a/go/arrow/compute/cast.go b/go/arrow/compute/cast.go
index 6bdb5d767c..a066bcccb2 100644
--- a/go/arrow/compute/cast.go
+++ b/go/arrow/compute/cast.go
@@ -22,6 +22,7 @@ import (
"sync"
"github.com/apache/arrow/go/v10/arrow"
+ "github.com/apache/arrow/go/v10/arrow/array"
"github.com/apache/arrow/go/v10/arrow/compute/internal/exec"
"github.com/apache/arrow/go/v10/arrow/compute/internal/kernels"
)
@@ -132,8 +133,28 @@ func (cf *castFunction) DispatchExact(vals ...arrow.DataType) (exec.Kernel, erro
return candidates[0], nil
}
+func CastFromExtension(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error {
+ opts := ctx.State.(kernels.CastState)
+
+ arr := batch.Values[0].Array.MakeArray().(array.ExtensionArray)
+ defer arr.Release()
+
+ castOpts := CastOptions(opts)
+ result, err := CastArray(ctx.Ctx, arr.Storage(), &castOpts)
+ if err != nil {
+ return err
+ }
+ defer result.Release()
+
+ out.TakeOwnership(result.Data())
+ return nil
+}
+
func addCastFuncs(fn []*castFunction) {
for _, f := range fn {
+ f.AddNewTypeCast(arrow.EXTENSION, []exec.InputType{exec.NewIDInput(arrow.EXTENSION)},
+ f.kernels[0].Signature.OutType, CastFromExtension,
+ exec.NullComputedNoPrealloc, exec.MemNoPrealloc)
castTable[f.out] = f
}
}
diff --git a/go/arrow/compute/cast_test.go b/go/arrow/compute/cast_test.go
index 7c5c422713..ae9b6c76bf 100644
--- a/go/arrow/compute/cast_test.go
+++ b/go/arrow/compute/cast_test.go
@@ -1222,6 +1222,56 @@ func (c *CastSuite) TestIdentityCasts() {
c.checkCastZeroCopy(arrow.FixedWidthTypes.Boolean, `[false, true, null, false]`)
}
+func (c *CastSuite) smallIntArrayFromJSON(data string) arrow.Array {
+ arr, _, _ := array.FromJSON(c.mem, types.NewSmallintType(), strings.NewReader(data))
+ return arr
+}
+
+func (c *CastSuite) TestExtensionTypeToIntDowncast() {
+ smallint := types.NewSmallintType()
+ arrow.RegisterExtensionType(smallint)
+ defer arrow.UnregisterExtensionType("smallint")
+
+ c.Run("smallint(int16) to int16", func() {
+ arr := c.smallIntArrayFromJSON(`[0, 100, 200, 1, 2]`)
+ defer arr.Release()
+
+ checkCastZeroCopy(c.T(), arr, arrow.PrimitiveTypes.Int16, compute.DefaultCastOptions(true))
+
+ c.checkCast(smallint, arrow.PrimitiveTypes.Uint8,
+ `[0, 100, 200, 1, 2]`, `[0, 100, 200, 1, 2]`)
+ })
+
+ c.Run("smallint(int16) to uint8 with overflow", func() {
+ opts := compute.SafeCastOptions(arrow.PrimitiveTypes.Uint8)
+ c.checkCastFails(smallint, `[0, null, 256, 1, 3]`, opts)
+
+ opts.AllowIntOverflow = true
+ c.checkCastOpts(smallint, arrow.PrimitiveTypes.Uint8,
+ `[0, null, 256, 1, 3]`, `[0, null, 0, 1, 3]`, *opts)
+ })
+
+ c.Run("smallint(int16) to uint8 with underflow", func() {
+ opts := compute.SafeCastOptions(arrow.PrimitiveTypes.Uint8)
+ c.checkCastFails(smallint, `[0, null, -1, 1, 3]`, opts)
+
+ opts.AllowIntOverflow = true
+ c.checkCastOpts(smallint, arrow.PrimitiveTypes.Uint8,
+ `[0, null, -1, 1, 3]`, `[0, null, 255, 1, 3]`, *opts)
+ })
+}
+
+func (c *CastSuite) TestNoOutBitmapIfIsAllValid() {
+ a, _, _ := array.FromJSON(c.mem, arrow.PrimitiveTypes.Int8, strings.NewReader(`[1]`))
+ defer a.Release()
+
+ opts := compute.SafeCastOptions(arrow.PrimitiveTypes.Int32)
+ result, err := compute.CastArray(context.Background(), a, opts)
+ c.NoError(err)
+ c.NotNil(a.Data().Buffers()[0])
+ c.Nil(result.Data().Buffers()[0])
+}
+
func TestCasts(t *testing.T) {
suite.Run(t, new(CastSuite))
}
diff --git a/go/arrow/compute/datum.go b/go/arrow/compute/datum.go
index 4243344637..b5b88613b4 100644
--- a/go/arrow/compute/datum.go
+++ b/go/arrow/compute/datum.go
@@ -122,6 +122,10 @@ type releasable interface {
}
func (d *ScalarDatum) Release() {
+ if !d.Value.IsValid() {
+ return
+ }
+
if v, ok := d.Value.(releasable); ok {
v.Release()
}
diff --git a/go/arrow/compute/internal/exec/span.go b/go/arrow/compute/internal/exec/span.go
index d412838be1..c969897eda 100644
--- a/go/arrow/compute/internal/exec/span.go
+++ b/go/arrow/compute/internal/exec/span.go
@@ -392,6 +392,66 @@ func (a *ArraySpan) FillFromScalar(val scalar.Scalar) {
}
}
+// TakeOwnership is like SetMembers only this takes ownership of
+// the buffers by calling Retain on them so that the passed in
+// ArrayData can be released without negatively affecting this
+// ArraySpan
+func (a *ArraySpan) TakeOwnership(data arrow.ArrayData) {
+ a.Type = data.DataType()
+ a.Len = int64(data.Len())
+ if a.Type.ID() == arrow.NULL {
+ a.Nulls = a.Len
+ } else {
+ a.Nulls = int64(data.NullN())
+ }
+ a.Offset = int64(data.Offset())
+
+ for i, b := range data.Buffers() {
+ if b != nil {
+ a.Buffers[i].WrapBuffer(b)
+ b.Retain()
+ } else {
+ a.Buffers[i].Buf = nil
+ a.Buffers[i].Owner = nil
+ a.Buffers[i].SelfAlloc = false
+ }
+ }
+
+ typeID := a.Type.ID()
+ if a.Buffers[0].Buf == nil {
+ switch typeID {
+ case arrow.NULL, arrow.SPARSE_UNION, arrow.DENSE_UNION:
+ default:
+ // should already be zero, but we make sure
+ a.Nulls = 0
+ }
+ }
+
+ for i := len(data.Buffers()); i < 3; i++ {
+ a.Buffers[i].Buf = nil
+ a.Buffers[i].Owner = nil
+ a.Buffers[i].SelfAlloc = false
+ }
+
+ if typeID == arrow.DICTIONARY {
+ if cap(a.Children) >= 1 {
+ a.Children = a.Children[:1]
+ } else {
+ a.Children = make([]ArraySpan, 1)
+ }
+ a.Children[0].TakeOwnership(data.Dictionary())
+ } else {
+ if cap(a.Children) >= len(data.Children()) {
+ a.Children = a.Children[:len(data.Children())]
+ } else {
+ a.Children = make([]ArraySpan, len(data.Children()))
+ }
+ for i, c := range data.Children() {
+ a.Children[i].TakeOwnership(c)
+ }
+ }
+}
+
// SetMembers populates this ArraySpan from the given ArrayData object.
// As this is a non-owning reference, the ArrayData object must not
// be fully released while this ArraySpan is in use, otherwise any buffers
diff --git a/go/arrow/scalar/scalar.go b/go/arrow/scalar/scalar.go
index a35eb519ba..51f16e36c7 100644
--- a/go/arrow/scalar/scalar.go
+++ b/go/arrow/scalar/scalar.go
@@ -538,13 +538,15 @@ func init() {
}
return NewDenseUnionScalar(MakeNullScalar(typ.Fields()[0].Type), typ.TypeCodes()[0], typ)
},
- arrow.DICTIONARY: func(dt arrow.DataType) Scalar { return NewNullDictScalar(dt) },
- arrow.LARGE_STRING: func(dt arrow.DataType) Scalar { return &LargeString{&String{&Binary{scalar: scalar{dt, false}}}} },
- arrow.LARGE_BINARY: func(dt arrow.DataType) Scalar { return &LargeBinary{&Binary{scalar: scalar{dt, false}}} },
- arrow.LARGE_LIST: func(dt arrow.DataType) Scalar { return &LargeList{&List{scalar: scalar{dt, false}}} },
- arrow.DECIMAL256: func(dt arrow.DataType) Scalar { return &Decimal256{scalar: scalar{dt, false}} },
- arrow.MAP: func(dt arrow.DataType) Scalar { return &Map{&List{scalar: scalar{dt, false}}} },
- arrow.EXTENSION: func(dt arrow.DataType) Scalar { return &Extension{scalar: scalar{dt, false}} },
+ arrow.DICTIONARY: func(dt arrow.DataType) Scalar { return NewNullDictScalar(dt) },
+ arrow.LARGE_STRING: func(dt arrow.DataType) Scalar { return &LargeString{&String{&Binary{scalar: scalar{dt, false}}}} },
+ arrow.LARGE_BINARY: func(dt arrow.DataType) Scalar { return &LargeBinary{&Binary{scalar: scalar{dt, false}}} },
+ arrow.LARGE_LIST: func(dt arrow.DataType) Scalar { return &LargeList{&List{scalar: scalar{dt, false}}} },
+ arrow.DECIMAL256: func(dt arrow.DataType) Scalar { return &Decimal256{scalar: scalar{dt, false}} },
+ arrow.MAP: func(dt arrow.DataType) Scalar { return &Map{&List{scalar: scalar{dt, false}}} },
+ arrow.EXTENSION: func(dt arrow.DataType) Scalar {
+ return &Extension{scalar: scalar{dt, false}, Value: MakeNullScalar(dt.(arrow.ExtensionType).StorageType())}
+ },
arrow.FIXED_SIZE_LIST: func(dt arrow.DataType) Scalar { return &FixedSizeList{&List{scalar: scalar{dt, false}}} },
arrow.DURATION: func(dt arrow.DataType) Scalar { return &Duration{scalar: scalar{dt, false}} },
// invalid data types to fill out array size 2^6 - 1