This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-go.git


The following commit(s) were added to refs/heads/main by this push:
     new 8598fb3  feat(arrow/compute): support some float16 casts (#430)
8598fb3 is described below

commit 8598fb3bd433b713424c4390fc9b785fc33fa5ee
Author: Matt Topol <zotthewiz...@gmail.com>
AuthorDate: Mon Jul 7 15:24:33 2025 -0400

    feat(arrow/compute): support some float16 casts (#430)
    
    closes #424
    
    ### Rationale for this change
    Support casting float16 arrays to/from int and float32/float64
    
    ### What changes are included in this PR?
    Implementation of new casting kernels for cast_float, cast_half_float,
    cast_int32
    
    ### Are these changes tested?
    Unit tests are added to account for this.
    
    ### Are there any user-facing changes?
    The only user-facing change is that `float16.Num` is no longer a struct
    with a `bits` member; it is now a type definition instead.
    
    As a result, any existing use of `float16.Num{}` must be replaced with
    `float16.Num(0)`.
---
 arrow/compute/cast_test.go                     |   8 +-
 arrow/compute/internal/kernels/cast_numeric.go |  46 ++++++++-
 arrow/compute/internal/kernels/helpers.go      |   6 +-
 arrow/compute/internal/kernels/numeric_cast.go | 133 ++++++++++++++++++++++---
 arrow/float16/float16.go                       |   7 +-
 arrow/float16/float16_test.go                  |   2 +-
 6 files changed, 181 insertions(+), 21 deletions(-)

diff --git a/arrow/compute/cast_test.go b/arrow/compute/cast_test.go
index 4e5f0a5..370ced1 100644
--- a/arrow/compute/cast_test.go
+++ b/arrow/compute/cast_test.go
@@ -574,7 +574,7 @@ func (c *CastSuite) TestToIntDowncastUnsafe() {
 }
 
 func (c *CastSuite) TestFloatingToInt() {
-       for _, from := range []arrow.DataType{arrow.PrimitiveTypes.Float32, 
arrow.PrimitiveTypes.Float64} {
+       for _, from := range []arrow.DataType{arrow.PrimitiveTypes.Float32, 
arrow.PrimitiveTypes.Float64, arrow.FixedWidthTypes.Float16} {
                for _, to := range []arrow.DataType{arrow.PrimitiveTypes.Int32, 
arrow.PrimitiveTypes.Int64} {
                        // float to int no truncation
                        c.checkCast(from, to, `[1.0, null, 0.0, -1.0, 5.0]`, 
`[1, null, 0, -1, 5]`)
@@ -590,6 +590,12 @@ func (c *CastSuite) TestFloatingToInt() {
        }
 }
 
+func (c *CastSuite) TestFloat16ToFloating() {
+       for _, to := range []arrow.DataType{arrow.PrimitiveTypes.Float32, 
arrow.PrimitiveTypes.Float64} {
+               c.checkCast(arrow.FixedWidthTypes.Float16, to, `[1.5, null, 
0.0, -1.5, 5.5]`, `[1.5, null, 0.0, -1.5, 5.5]`)
+       }
+}
+
 func (c *CastSuite) TestIntToFloating() {
        for _, from := range []arrow.DataType{arrow.PrimitiveTypes.Uint32, 
arrow.PrimitiveTypes.Int32} {
                two24 := `[16777216, 16777217]`
diff --git a/arrow/compute/internal/kernels/cast_numeric.go 
b/arrow/compute/internal/kernels/cast_numeric.go
index a177259..6969d82 100644
--- a/arrow/compute/internal/kernels/cast_numeric.go
+++ b/arrow/compute/internal/kernels/cast_numeric.go
@@ -22,6 +22,7 @@ import (
        "unsafe"
 
        "github.com/apache/arrow-go/v18/arrow"
+       "github.com/apache/arrow-go/v18/arrow/float16"
 )
 
 var castNumericUnsafe func(itype, otype arrow.Type, in, out []byte, len int) = 
castNumericGo
@@ -32,7 +33,19 @@ func DoStaticCast[InT, OutT numeric](in []InT, out []OutT) {
        }
 }
 
-func reinterpret[T numeric](b []byte, len int) (res []T) {
+func DoFloat16Cast[InT numeric](in []InT, out []float16.Num) {
+       for i, v := range in {
+               out[i] = float16.New(float32(v))
+       }
+}
+
+func DoFloat16CastToNumber[OutT numeric](in []float16.Num, out []OutT) {
+       for i, v := range in {
+               out[i] = OutT(v.Float32())
+       }
+}
+
+func reinterpret[T numeric | float16.Num](b []byte, len int) (res []T) {
        return unsafe.Slice((*T)(unsafe.Pointer(&b[0])), len)
 }
 
@@ -54,6 +67,8 @@ func castNumberToNumberUnsafeImpl[T numeric](outT arrow.Type, 
in []T, out []byte
                DoStaticCast(in, reinterpret[int64](out, len(in)))
        case arrow.UINT64:
                DoStaticCast(in, reinterpret[uint64](out, len(in)))
+       case arrow.FLOAT16:
+               DoFloat16Cast(in, reinterpret[float16.Num](out, len(in)))
        case arrow.FLOAT32:
                DoStaticCast(in, reinterpret[float32](out, len(in)))
        case arrow.FLOAT64:
@@ -61,6 +76,33 @@ func castNumberToNumberUnsafeImpl[T numeric](outT 
arrow.Type, in []T, out []byte
        }
 }
 
+func castFloat16ToNumberUnsafeImpl(outT arrow.Type, in []float16.Num, out 
[]byte) {
+       switch outT {
+       case arrow.INT8:
+               DoFloat16CastToNumber(in, reinterpret[int8](out, len(in)))
+       case arrow.UINT8:
+               DoFloat16CastToNumber(in, reinterpret[uint8](out, len(in)))
+       case arrow.INT16:
+               DoFloat16CastToNumber(in, reinterpret[int16](out, len(in)))
+       case arrow.UINT16:
+               DoFloat16CastToNumber(in, reinterpret[uint16](out, len(in)))
+       case arrow.INT32:
+               DoFloat16CastToNumber(in, reinterpret[int32](out, len(in)))
+       case arrow.UINT32:
+               DoFloat16CastToNumber(in, reinterpret[uint32](out, len(in)))
+       case arrow.INT64:
+               DoFloat16CastToNumber(in, reinterpret[int64](out, len(in)))
+       case arrow.UINT64:
+               DoFloat16CastToNumber(in, reinterpret[uint64](out, len(in)))
+       case arrow.FLOAT16:
+               copy(reinterpret[float16.Num](out, len(in)), in)
+       case arrow.FLOAT32:
+               DoFloat16CastToNumber(in, reinterpret[float32](out, len(in)))
+       case arrow.FLOAT64:
+               DoFloat16CastToNumber(in, reinterpret[float64](out, len(in)))
+       }
+}
+
 func castNumericGo(itype, otype arrow.Type, in, out []byte, len int) {
        switch itype {
        case arrow.INT8:
@@ -79,6 +121,8 @@ func castNumericGo(itype, otype arrow.Type, in, out []byte, 
len int) {
                castNumberToNumberUnsafeImpl(otype, reinterpret[int64](in, 
len), out)
        case arrow.UINT64:
                castNumberToNumberUnsafeImpl(otype, reinterpret[uint64](in, 
len), out)
+       case arrow.FLOAT16:
+               castFloat16ToNumberUnsafeImpl(otype, 
reinterpret[float16.Num](in, len), out)
        case arrow.FLOAT32:
                castNumberToNumberUnsafeImpl(otype, reinterpret[float32](in, 
len), out)
        case arrow.FLOAT64:
diff --git a/arrow/compute/internal/kernels/helpers.go 
b/arrow/compute/internal/kernels/helpers.go
index 4a9ead1..ef5f0bb 100644
--- a/arrow/compute/internal/kernels/helpers.go
+++ b/arrow/compute/internal/kernels/helpers.go
@@ -695,7 +695,11 @@ func castNumberToNumberUnsafe(in, out *exec.ArraySpan) {
 
        inputOffset := in.Type.(arrow.FixedWidthDataType).Bytes() * 
int(in.Offset)
        outputOffset := out.Type.(arrow.FixedWidthDataType).Bytes() * 
int(out.Offset)
-       castNumericUnsafe(in.Type.ID(), out.Type.ID(), 
in.Buffers[1].Buf[inputOffset:], out.Buffers[1].Buf[outputOffset:], int(in.Len))
+       if in.Type.ID() == arrow.FLOAT16 || out.Type.ID() == arrow.FLOAT16 {
+               castNumericGo(in.Type.ID(), out.Type.ID(), 
in.Buffers[1].Buf[inputOffset:], out.Buffers[1].Buf[outputOffset:], int(in.Len))
+       } else {
+               castNumericUnsafe(in.Type.ID(), out.Type.ID(), 
in.Buffers[1].Buf[inputOffset:], out.Buffers[1].Buf[outputOffset:], int(in.Len))
+       }
 }
 
 func MaxDecimalDigitsForInt(id arrow.Type) (int32, error) {
diff --git a/arrow/compute/internal/kernels/numeric_cast.go 
b/arrow/compute/internal/kernels/numeric_cast.go
index 1e76709..7681b02 100644
--- a/arrow/compute/internal/kernels/numeric_cast.go
+++ b/arrow/compute/internal/kernels/numeric_cast.go
@@ -28,6 +28,7 @@ import (
        "github.com/apache/arrow-go/v18/arrow/compute/exec"
        "github.com/apache/arrow-go/v18/arrow/decimal128"
        "github.com/apache/arrow-go/v18/arrow/decimal256"
+       "github.com/apache/arrow-go/v18/arrow/float16"
        "github.com/apache/arrow-go/v18/arrow/internal/debug"
        "github.com/apache/arrow-go/v18/internal/bitutils"
        "golang.org/x/exp/constraints"
@@ -506,6 +507,27 @@ func CastFloat64ToDecimal(ctx *exec.KernelCtx, batch 
*exec.ExecSpan, out *exec.E
        return executor(ctx, batch, out)
 }
 
+func CastDecimalToFloat16(ctx *exec.KernelCtx, batch *exec.ExecSpan, out 
*exec.ExecResult) error {
+       var (
+               executor exec.ArrayKernelExec
+       )
+
+       switch dt := batch.Values[0].Array.Type.(type) {
+       case *arrow.Decimal128Type:
+               scale := dt.Scale
+               executor = ScalarUnaryNotNull(func(_ *exec.KernelCtx, v 
decimal128.Num, err *error) float16.Num {
+                       return float16.New(v.ToFloat32(scale))
+               })
+       case *arrow.Decimal256Type:
+               scale := dt.Scale
+               executor = ScalarUnaryNotNull(func(_ *exec.KernelCtx, v 
decimal256.Num, err *error) float16.Num {
+                       return float16.New(v.ToFloat32(scale))
+               })
+       }
+
+       return executor(ctx, batch, out)
+}
+
 func CastDecimalToFloating[OutT constraints.Float](ctx *exec.KernelCtx, batch 
*exec.ExecSpan, out *exec.ExecResult) error {
        var (
                executor exec.ArrayKernelExec
@@ -543,13 +565,49 @@ func boolToNum[T numeric](_ *exec.KernelCtx, in []byte, 
out []T) error {
        return nil
 }
 
-func checkFloatTrunc[InT constraints.Float, OutT arrow.IntType | 
arrow.UintType](in, out *exec.ArraySpan) error {
-       wasTrunc := func(out OutT, in InT) bool {
-               return InT(out) != in
+func boolToFloat16(_ *exec.KernelCtx, in []byte, out []float16.Num) error {
+       var (
+               zero float16.Num
+               one  = float16.New(1)
+       )
+
+       for i := range out {
+               if bitutil.BitIsSet(in, i) {
+                       out[i] = one
+               } else {
+                       out[i] = zero
+               }
        }
-       wasTruncMaybeNull := func(out OutT, in InT, isValid bool) bool {
-               return isValid && (InT(out) != in)
+       return nil
+}
+
+func wasTrunc[InT constraints.Float | float16.Num, OutT arrow.IntType | 
arrow.UintType](out OutT, in InT) bool {
+       switch v := any(in).(type) {
+       case float16.Num:
+               return float16.New(float32(out)) != v
+       case float32:
+               return float32(out) != v
+       case float64:
+               return float64(out) != v
+       default:
+               return false
+       }
+}
+
+func wasTruncMaybeNull[InT constraints.Float | float16.Num, OutT arrow.IntType 
| arrow.UintType](out OutT, in InT, isValid bool) bool {
+       switch v := any(in).(type) {
+       case float16.Num:
+               return isValid && (float16.New(float32(out)) != v)
+       case float32:
+               return isValid && (float32(out) != v)
+       case float64:
+               return isValid && (float64(out) != v)
+       default:
+               return false
        }
+}
+
+func checkFloatTrunc[InT constraints.Float | float16.Num, OutT arrow.IntType | 
arrow.UintType](in, out *exec.ArraySpan) error {
        getError := func(val InT) error {
                return fmt.Errorf("%w: float value %f was truncated converting 
to %s",
                        arrow.ErrInvalid, val, out.Type)
@@ -598,7 +656,7 @@ func checkFloatTrunc[InT constraints.Float, OutT 
arrow.IntType | arrow.UintType]
        return nil
 }
 
-func checkFloatToIntTruncImpl[T constraints.Float](in, out *exec.ArraySpan) 
error {
+func checkFloatToIntTruncImpl[T constraints.Float | float16.Num](in, out 
*exec.ArraySpan) error {
        switch out.Type.ID() {
        case arrow.INT8:
                return checkFloatTrunc[T, int8](in, out)
@@ -623,6 +681,8 @@ func checkFloatToIntTruncImpl[T constraints.Float](in, out 
*exec.ArraySpan) erro
 
 func checkFloatToIntTrunc(in, out *exec.ArraySpan) error {
        switch in.Type.ID() {
+       case arrow.FLOAT16:
+               return checkFloatToIntTruncImpl[float16.Num](in, out)
        case arrow.FLOAT32:
                return checkFloatToIntTruncImpl[float32](in, out)
        case arrow.FLOAT64:
@@ -729,6 +789,26 @@ func getParseStringExec[OffsetT int32 | int64](out 
arrow.Type) exec.ArrayKernelE
        panic("invalid type for getParseStringExec")
 }
 
+func addFloat16Casts(outTy arrow.DataType, kernels []exec.ScalarKernel) 
[]exec.ScalarKernel {
+       kernels = append(kernels, GetCommonCastKernels(outTy.ID(), 
exec.NewOutputType(outTy))...)
+
+       kernels = append(kernels, exec.NewScalarKernel(
+               
[]exec.InputType{exec.NewExactInput(arrow.FixedWidthTypes.Boolean)},
+               exec.NewOutputType(outTy), ScalarUnaryBoolArg(boolToFloat16), 
nil))
+
+       for _, inTy := range []arrow.DataType{arrow.BinaryTypes.Binary, 
arrow.BinaryTypes.String} {
+               kernels = append(kernels, exec.NewScalarKernel(
+                       []exec.InputType{exec.NewExactInput(inTy)}, 
exec.NewOutputType(outTy),
+                       getParseStringExec[int32](outTy.ID()), nil))
+       }
+       for _, inTy := range []arrow.DataType{arrow.BinaryTypes.LargeBinary, 
arrow.BinaryTypes.LargeString} {
+               kernels = append(kernels, exec.NewScalarKernel(
+                       []exec.InputType{exec.NewExactInput(inTy)}, 
exec.NewOutputType(outTy),
+                       getParseStringExec[int64](outTy.ID()), nil))
+       }
+       return kernels
+}
+
 func addCommonNumberCasts[T numeric](outTy arrow.DataType, kernels 
[]exec.ScalarKernel) []exec.ScalarKernel {
        kernels = append(kernels, GetCommonCastKernels(outTy.ID(), 
exec.NewOutputType(outTy))...)
 
@@ -759,7 +839,7 @@ func GetCastToInteger[T arrow.IntType | 
arrow.UintType](outType arrow.DataType)
                        CastIntToInt, nil))
        }
 
-       for _, inTy := range floatingTypes {
+       for _, inTy := range append(floatingTypes, 
arrow.FixedWidthTypes.Float16) {
                kernels = append(kernels, exec.NewScalarKernel(
                        []exec.InputType{exec.NewExactInput(inTy)}, output,
                        CastFloatingToInteger, nil))
@@ -775,7 +855,7 @@ func GetCastToInteger[T arrow.IntType | 
arrow.UintType](outType arrow.DataType)
        return kernels
 }
 
-func GetCastToFloating[T constraints.Float](outType arrow.DataType) 
[]exec.ScalarKernel {
+func GetCastToFloating[T constraints.Float | float16.Num](outType 
arrow.DataType) []exec.ScalarKernel {
        kernels := make([]exec.ScalarKernel, 0)
 
        output := exec.NewOutputType(outType)
@@ -785,19 +865,40 @@ func GetCastToFloating[T constraints.Float](outType 
arrow.DataType) []exec.Scala
                        CastIntegerToFloating, nil))
        }
 
-       for _, inTy := range floatingTypes {
+       for _, inTy := range append(floatingTypes, 
arrow.FixedWidthTypes.Float16) {
                kernels = append(kernels, exec.NewScalarKernel(
                        []exec.InputType{exec.NewExactInput(inTy)}, output,
                        CastFloatingToFloating, nil))
        }
 
-       kernels = addCommonNumberCasts[T](outType, kernels)
-       kernels = append(kernels, exec.NewScalarKernel(
-               []exec.InputType{exec.NewIDInput(arrow.DECIMAL128)}, output,
-               CastDecimalToFloating[T], nil))
-       kernels = append(kernels, exec.NewScalarKernel(
-               []exec.InputType{exec.NewIDInput(arrow.DECIMAL256)}, output,
-               CastDecimalToFloating[T], nil))
+       var z T
+       switch any(z).(type) {
+       case float16.Num:
+               kernels = addFloat16Casts(outType, kernels)
+               kernels = append(kernels, exec.NewScalarKernel(
+                       []exec.InputType{exec.NewIDInput(arrow.DECIMAL128)}, 
output,
+                       CastDecimalToFloat16, nil))
+               kernels = append(kernels, exec.NewScalarKernel(
+                       []exec.InputType{exec.NewIDInput(arrow.DECIMAL256)}, 
output,
+                       CastDecimalToFloat16, nil))
+       case float32:
+               kernels = addCommonNumberCasts[float32](outType, kernels)
+               kernels = append(kernels, exec.NewScalarKernel(
+                       []exec.InputType{exec.NewIDInput(arrow.DECIMAL128)}, 
output,
+                       CastDecimalToFloating[float32], nil))
+               kernels = append(kernels, exec.NewScalarKernel(
+                       []exec.InputType{exec.NewIDInput(arrow.DECIMAL256)}, 
output,
+                       CastDecimalToFloating[float32], nil))
+       case float64:
+               kernels = addCommonNumberCasts[float64](outType, kernels)
+               kernels = append(kernels, exec.NewScalarKernel(
+                       []exec.InputType{exec.NewIDInput(arrow.DECIMAL128)}, 
output,
+                       CastDecimalToFloating[float64], nil))
+               kernels = append(kernels, exec.NewScalarKernel(
+                       []exec.InputType{exec.NewIDInput(arrow.DECIMAL256)}, 
output,
+                       CastDecimalToFloating[float64], nil))
+       }
+
        return kernels
 }
 
diff --git a/arrow/float16/float16.go b/arrow/float16/float16.go
index 0aa4df8..f6c276b 100644
--- a/arrow/float16/float16.go
+++ b/arrow/float16/float16.go
@@ -18,6 +18,7 @@ package float16
 
 import (
        "encoding/binary"
+       "fmt"
        "math"
        "strconv"
 )
@@ -58,6 +59,10 @@ func New(f float32) Num {
        return Num{bits: (sn << 15) | uint16(res<<10) | fc}
 }
 
+func (f Num) Format(s fmt.State, verb rune) {
+       fmt.Fprintf(s, fmt.FormatString(s, verb), f.Float32())
+}
+
 func (f Num) Float32() float32 {
        sn := uint32((f.bits >> 15) & 0x1)
        exp := (f.bits >> 10) & 0x1f
@@ -179,7 +184,7 @@ func (n Num) IsInf() bool { return (n.bits & 0x7c00) == 
0x7c00 }
 
 func (n Num) IsZero() bool { return (n.bits & 0x7fff) == 0 }
 
-func (f Num) Uint16() uint16 { return f.bits }
+func (f Num) Uint16() uint16 { return uint16(f.bits) }
 func (f Num) String() string { return 
strconv.FormatFloat(float64(f.Float32()), 'g', -1, 32) }
 
 func Inf() Num { return Num{bits: 0x7c00} }
diff --git a/arrow/float16/float16_test.go b/arrow/float16/float16_test.go
index cfde440..9857eda 100644
--- a/arrow/float16/float16_test.go
+++ b/arrow/float16/float16_test.go
@@ -38,7 +38,7 @@ func TestFloat16(t *testing.T) {
                f := k.Float32()
                assert.Equal(t, v, f, "float32 values should be the same")
                i := New(v)
-               assert.Equal(t, k.bits, i.bits, "float16 values should be the 
same")
+               assert.Equal(t, k, i, "float16 values should be the same")
                assert.Equal(t, k.Uint16(), i.Uint16(), "float16 values should 
be the same")
                assert.Equal(t, k.String(), fmt.Sprintf("%v", v), "string 
representation differ")
        }

Reply via email to