This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new c72ad1689b ARROW-17587: [Go] Cast From Extension Types (#14016)
c72ad1689b is described below

commit c72ad1689b57b7f48d40d9a353f3337218f5a219
Author: Matt Topol <[email protected]>
AuthorDate: Thu Sep 1 11:14:07 2022 -0400

    ARROW-17587: [Go] Cast From Extension Types (#14016)
    
    Authored-by: Matt Topol <[email protected]>
    Signed-off-by: Matt Topol <[email protected]>
---
 go/arrow/compute/cast.go               | 21 ++++++++++++
 go/arrow/compute/cast_test.go          | 50 ++++++++++++++++++++++++++++
 go/arrow/compute/datum.go              |  4 +++
 go/arrow/compute/internal/exec/span.go | 60 ++++++++++++++++++++++++++++++++++
 go/arrow/scalar/scalar.go              | 16 +++++----
 5 files changed, 144 insertions(+), 7 deletions(-)

diff --git a/go/arrow/compute/cast.go b/go/arrow/compute/cast.go
index 6bdb5d767c..a066bcccb2 100644
--- a/go/arrow/compute/cast.go
+++ b/go/arrow/compute/cast.go
@@ -22,6 +22,7 @@ import (
        "sync"
 
        "github.com/apache/arrow/go/v10/arrow"
+       "github.com/apache/arrow/go/v10/arrow/array"
        "github.com/apache/arrow/go/v10/arrow/compute/internal/exec"
        "github.com/apache/arrow/go/v10/arrow/compute/internal/kernels"
 )
@@ -132,8 +133,28 @@ func (cf *castFunction) DispatchExact(vals ...arrow.DataType) (exec.Kernel, erro
        return candidates[0], nil
 }
 
+func CastFromExtension(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error {
+       opts := ctx.State.(kernels.CastState)
+
+       arr := batch.Values[0].Array.MakeArray().(array.ExtensionArray)
+       defer arr.Release()
+
+       castOpts := CastOptions(opts)
+       result, err := CastArray(ctx.Ctx, arr.Storage(), &castOpts)
+       if err != nil {
+               return err
+       }
+       defer result.Release()
+
+       out.TakeOwnership(result.Data())
+       return nil
+}
+
 func addCastFuncs(fn []*castFunction) {
        for _, f := range fn {
+               f.AddNewTypeCast(arrow.EXTENSION, []exec.InputType{exec.NewIDInput(arrow.EXTENSION)},
+                       f.kernels[0].Signature.OutType, CastFromExtension,
+                       exec.NullComputedNoPrealloc, exec.MemNoPrealloc)
                castTable[f.out] = f
        }
 }
diff --git a/go/arrow/compute/cast_test.go b/go/arrow/compute/cast_test.go
index 7c5c422713..ae9b6c76bf 100644
--- a/go/arrow/compute/cast_test.go
+++ b/go/arrow/compute/cast_test.go
@@ -1222,6 +1222,56 @@ func (c *CastSuite) TestIdentityCasts() {
        c.checkCastZeroCopy(arrow.FixedWidthTypes.Boolean, `[false, true, null, false]`)
 }
 
+func (c *CastSuite) smallIntArrayFromJSON(data string) arrow.Array {
+       arr, _, _ := array.FromJSON(c.mem, types.NewSmallintType(), strings.NewReader(data))
+       return arr
+}
+
+func (c *CastSuite) TestExtensionTypeToIntDowncast() {
+       smallint := types.NewSmallintType()
+       arrow.RegisterExtensionType(smallint)
+       defer arrow.UnregisterExtensionType("smallint")
+
+       c.Run("smallint(int16) to int16", func() {
+               arr := c.smallIntArrayFromJSON(`[0, 100, 200, 1, 2]`)
+               defer arr.Release()
+
+               checkCastZeroCopy(c.T(), arr, arrow.PrimitiveTypes.Int16, compute.DefaultCastOptions(true))
+
+               c.checkCast(smallint, arrow.PrimitiveTypes.Uint8,
+                       `[0, 100, 200, 1, 2]`, `[0, 100, 200, 1, 2]`)
+       })
+
+       c.Run("smallint(int16) to uint8 with overflow", func() {
+               opts := compute.SafeCastOptions(arrow.PrimitiveTypes.Uint8)
+               c.checkCastFails(smallint, `[0, null, 256, 1, 3]`, opts)
+
+               opts.AllowIntOverflow = true
+               c.checkCastOpts(smallint, arrow.PrimitiveTypes.Uint8,
+                       `[0, null, 256, 1, 3]`, `[0, null, 0, 1, 3]`, *opts)
+       })
+
+       c.Run("smallint(int16) to uint8 with underflow", func() {
+               opts := compute.SafeCastOptions(arrow.PrimitiveTypes.Uint8)
+               c.checkCastFails(smallint, `[0, null, -1, 1, 3]`, opts)
+
+               opts.AllowIntOverflow = true
+               c.checkCastOpts(smallint, arrow.PrimitiveTypes.Uint8,
+                       `[0, null, -1, 1, 3]`, `[0, null, 255, 1, 3]`, *opts)
+       })
+}
+
+func (c *CastSuite) TestNoOutBitmapIfIsAllValid() {
+       a, _, _ := array.FromJSON(c.mem, arrow.PrimitiveTypes.Int8, strings.NewReader(`[1]`))
+       defer a.Release()
+
+       opts := compute.SafeCastOptions(arrow.PrimitiveTypes.Int32)
+       result, err := compute.CastArray(context.Background(), a, opts)
+       c.NoError(err)
+       c.NotNil(a.Data().Buffers()[0])
+       c.Nil(result.Data().Buffers()[0])
+}
+
 func TestCasts(t *testing.T) {
        suite.Run(t, new(CastSuite))
 }
diff --git a/go/arrow/compute/datum.go b/go/arrow/compute/datum.go
index 4243344637..b5b88613b4 100644
--- a/go/arrow/compute/datum.go
+++ b/go/arrow/compute/datum.go
@@ -122,6 +122,10 @@ type releasable interface {
 }
 
 func (d *ScalarDatum) Release() {
+       if !d.Value.IsValid() {
+               return
+       }
+
        if v, ok := d.Value.(releasable); ok {
                v.Release()
        }
diff --git a/go/arrow/compute/internal/exec/span.go b/go/arrow/compute/internal/exec/span.go
index d412838be1..c969897eda 100644
--- a/go/arrow/compute/internal/exec/span.go
+++ b/go/arrow/compute/internal/exec/span.go
@@ -392,6 +392,66 @@ func (a *ArraySpan) FillFromScalar(val scalar.Scalar) {
        }
 }
 
+// TakeOwnership is like SetMembers only this takes ownership of
+// the buffers by calling Retain on them so that the passed in
+// ArrayData can be released without negatively affecting this
+// ArraySpan
+func (a *ArraySpan) TakeOwnership(data arrow.ArrayData) {
+       a.Type = data.DataType()
+       a.Len = int64(data.Len())
+       if a.Type.ID() == arrow.NULL {
+               a.Nulls = a.Len
+       } else {
+               a.Nulls = int64(data.NullN())
+       }
+       a.Offset = int64(data.Offset())
+
+       for i, b := range data.Buffers() {
+               if b != nil {
+                       a.Buffers[i].WrapBuffer(b)
+                       b.Retain()
+               } else {
+                       a.Buffers[i].Buf = nil
+                       a.Buffers[i].Owner = nil
+                       a.Buffers[i].SelfAlloc = false
+               }
+       }
+
+       typeID := a.Type.ID()
+       if a.Buffers[0].Buf == nil {
+               switch typeID {
+               case arrow.NULL, arrow.SPARSE_UNION, arrow.DENSE_UNION:
+               default:
+                       // should already be zero, but we make sure
+                       a.Nulls = 0
+               }
+       }
+
+       for i := len(data.Buffers()); i < 3; i++ {
+               a.Buffers[i].Buf = nil
+               a.Buffers[i].Owner = nil
+               a.Buffers[i].SelfAlloc = false
+       }
+
+       if typeID == arrow.DICTIONARY {
+               if cap(a.Children) >= 1 {
+                       a.Children = a.Children[:1]
+               } else {
+                       a.Children = make([]ArraySpan, 1)
+               }
+               a.Children[0].TakeOwnership(data.Dictionary())
+       } else {
+               if cap(a.Children) >= len(data.Children()) {
+                       a.Children = a.Children[:len(data.Children())]
+               } else {
+                       a.Children = make([]ArraySpan, len(data.Children()))
+               }
+               for i, c := range data.Children() {
+                       a.Children[i].TakeOwnership(c)
+               }
+       }
+}
+
 // SetMembers populates this ArraySpan from the given ArrayData object.
 // As this is a non-owning reference, the ArrayData object must not
 // be fully released while this ArraySpan is in use, otherwise any buffers
diff --git a/go/arrow/scalar/scalar.go b/go/arrow/scalar/scalar.go
index a35eb519ba..51f16e36c7 100644
--- a/go/arrow/scalar/scalar.go
+++ b/go/arrow/scalar/scalar.go
@@ -538,13 +538,15 @@ func init() {
                        }
                        return NewDenseUnionScalar(MakeNullScalar(typ.Fields()[0].Type), typ.TypeCodes()[0], typ)
                },
-               arrow.DICTIONARY:      func(dt arrow.DataType) Scalar { return NewNullDictScalar(dt) },
-               arrow.LARGE_STRING:    func(dt arrow.DataType) Scalar { return &LargeString{&String{&Binary{scalar: scalar{dt, false}}}} },
-               arrow.LARGE_BINARY:    func(dt arrow.DataType) Scalar { return &LargeBinary{&Binary{scalar: scalar{dt, false}}} },
-               arrow.LARGE_LIST:      func(dt arrow.DataType) Scalar { return &LargeList{&List{scalar: scalar{dt, false}}} },
-               arrow.DECIMAL256:      func(dt arrow.DataType) Scalar { return &Decimal256{scalar: scalar{dt, false}} },
-               arrow.MAP:             func(dt arrow.DataType) Scalar { return &Map{&List{scalar: scalar{dt, false}}} },
-               arrow.EXTENSION:       func(dt arrow.DataType) Scalar { return &Extension{scalar: scalar{dt, false}} },
+               arrow.DICTIONARY:   func(dt arrow.DataType) Scalar { return NewNullDictScalar(dt) },
+               arrow.LARGE_STRING: func(dt arrow.DataType) Scalar { return &LargeString{&String{&Binary{scalar: scalar{dt, false}}}} },
+               arrow.LARGE_BINARY: func(dt arrow.DataType) Scalar { return &LargeBinary{&Binary{scalar: scalar{dt, false}}} },
+               arrow.LARGE_LIST:   func(dt arrow.DataType) Scalar { return &LargeList{&List{scalar: scalar{dt, false}}} },
+               arrow.DECIMAL256:   func(dt arrow.DataType) Scalar { return &Decimal256{scalar: scalar{dt, false}} },
+               arrow.MAP:          func(dt arrow.DataType) Scalar { return &Map{&List{scalar: scalar{dt, false}}} },
+               arrow.EXTENSION: func(dt arrow.DataType) Scalar {
+                       return &Extension{scalar: scalar{dt, false}, Value: MakeNullScalar(dt.(arrow.ExtensionType).StorageType())}
+               },
                arrow.FIXED_SIZE_LIST: func(dt arrow.DataType) Scalar { return &FixedSizeList{&List{scalar: scalar{dt, false}}} },
                arrow.DURATION:        func(dt arrow.DataType) Scalar { return &Duration{scalar: scalar{dt, false}} },
                // invalid data types to fill out array size 2^6 - 1

Reply via email to