This is an automated email from the ASF dual-hosted git repository. zeroshade pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/arrow-go.git
The following commit(s) were added to refs/heads/main by this push: new 8598fb3 feat(arrow/compute): support some float16 casts (#430) 8598fb3 is described below commit 8598fb3bd433b713424c4390fc9b785fc33fa5ee Author: Matt Topol <zotthewiz...@gmail.com> AuthorDate: Mon Jul 7 15:24:33 2025 -0400 feat(arrow/compute): support some float16 casts (#430) closes #424 ### Rationale for this change Support casting float16 arrays to/from int and float32/float64 ### What changes are included in this PR? Implementation of new casting kernels for cast_float, cast_half_float, cast_int32 ### Are these changes tested? Unit tests are added to account for this. ### Are there any user-facing changes? The only user-facing change is that float16.Num is no longer a struct with bits member but instead a type definition. As a result, using `float16.Num{}` must be replaced with `float16.Num(0)` if it is used. --- arrow/compute/cast_test.go | 8 +- arrow/compute/internal/kernels/cast_numeric.go | 46 ++++++++- arrow/compute/internal/kernels/helpers.go | 6 +- arrow/compute/internal/kernels/numeric_cast.go | 133 ++++++++++++++++++++++--- arrow/float16/float16.go | 7 +- arrow/float16/float16_test.go | 2 +- 6 files changed, 181 insertions(+), 21 deletions(-) diff --git a/arrow/compute/cast_test.go b/arrow/compute/cast_test.go index 4e5f0a5..370ced1 100644 --- a/arrow/compute/cast_test.go +++ b/arrow/compute/cast_test.go @@ -574,7 +574,7 @@ func (c *CastSuite) TestToIntDowncastUnsafe() { } func (c *CastSuite) TestFloatingToInt() { - for _, from := range []arrow.DataType{arrow.PrimitiveTypes.Float32, arrow.PrimitiveTypes.Float64} { + for _, from := range []arrow.DataType{arrow.PrimitiveTypes.Float32, arrow.PrimitiveTypes.Float64, arrow.FixedWidthTypes.Float16} { for _, to := range []arrow.DataType{arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Int64} { // float to int no truncation c.checkCast(from, to, `[1.0, null, 0.0, -1.0, 5.0]`, `[1, null, 0, -1, 5]`) @@ -590,6 +590,12 @@ func (c *CastSuite) TestFloatingToInt() { } } +func (c *CastSuite) TestFloat16ToFloating() { + for _, to := range []arrow.DataType{arrow.PrimitiveTypes.Float32, arrow.PrimitiveTypes.Float64} { + c.checkCast(arrow.FixedWidthTypes.Float16, to, `[1.5, null, 0.0, -1.5, 5.5]`, `[1.5, null, 0.0, -1.5, 5.5]`) + } +} + func (c *CastSuite) TestIntToFloating() { for _, from := range []arrow.DataType{arrow.PrimitiveTypes.Uint32, arrow.PrimitiveTypes.Int32} { two24 := `[16777216, 16777217]` diff --git a/arrow/compute/internal/kernels/cast_numeric.go b/arrow/compute/internal/kernels/cast_numeric.go index a177259..6969d82 100644 --- a/arrow/compute/internal/kernels/cast_numeric.go +++ b/arrow/compute/internal/kernels/cast_numeric.go @@ -22,6 +22,7 @@ import ( "unsafe" "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/float16" ) var castNumericUnsafe func(itype, otype arrow.Type, in, out []byte, len int) = castNumericGo @@ -32,7 +33,19 @@ func DoStaticCast[InT, OutT numeric](in []InT, out []OutT) { } } -func reinterpret[T numeric](b []byte, len int) (res []T) { +func DoFloat16Cast[InT numeric](in []InT, out []float16.Num) { + for i, v := range in { + out[i] = float16.New(float32(v)) + } +} + +func DoFloat16CastToNumber[OutT numeric](in []float16.Num, out []OutT) { + for i, v := range in { + out[i] = OutT(v.Float32()) + } +} + +func reinterpret[T numeric | float16.Num](b []byte, len int) (res []T) { return unsafe.Slice((*T)(unsafe.Pointer(&b[0])), len) } @@ -54,6 +67,8 @@ func castNumberToNumberUnsafeImpl[T numeric](outT arrow.Type, in []T, out []byte DoStaticCast(in, reinterpret[int64](out, len(in))) case arrow.UINT64: DoStaticCast(in, reinterpret[uint64](out, len(in))) + case arrow.FLOAT16: + DoFloat16Cast(in, reinterpret[float16.Num](out, len(in))) case arrow.FLOAT32: DoStaticCast(in, reinterpret[float32](out, len(in))) case arrow.FLOAT64: @@ -61,6 +76,33 @@ func castNumberToNumberUnsafeImpl[T numeric](outT arrow.Type, in []T, out []byte } } +func castFloat16ToNumberUnsafeImpl(outT arrow.Type, in []float16.Num, out []byte) { + switch outT { + case arrow.INT8: + DoFloat16CastToNumber(in, reinterpret[int8](out, len(in))) + case arrow.UINT8: + DoFloat16CastToNumber(in, reinterpret[uint8](out, len(in))) + case arrow.INT16: + DoFloat16CastToNumber(in, reinterpret[int16](out, len(in))) + case arrow.UINT16: + DoFloat16CastToNumber(in, reinterpret[uint16](out, len(in))) + case arrow.INT32: + DoFloat16CastToNumber(in, reinterpret[int32](out, len(in))) + case arrow.UINT32: + DoFloat16CastToNumber(in, reinterpret[uint32](out, len(in))) + case arrow.INT64: + DoFloat16CastToNumber(in, reinterpret[int64](out, len(in))) + case arrow.UINT64: + DoFloat16CastToNumber(in, reinterpret[uint64](out, len(in))) + case arrow.FLOAT16: + copy(reinterpret[float16.Num](out, len(in)), in) + case arrow.FLOAT32: + DoFloat16CastToNumber(in, reinterpret[float32](out, len(in))) + case arrow.FLOAT64: + DoFloat16CastToNumber(in, reinterpret[float64](out, len(in))) + } +} + func castNumericGo(itype, otype arrow.Type, in, out []byte, len int) { switch itype { case arrow.INT8: @@ -79,6 +121,8 @@ func castNumericGo(itype, otype arrow.Type, in, out []byte, len int) { castNumberToNumberUnsafeImpl(otype, reinterpret[int64](in, len), out) case arrow.UINT64: castNumberToNumberUnsafeImpl(otype, reinterpret[uint64](in, len), out) + case arrow.FLOAT16: + castFloat16ToNumberUnsafeImpl(otype, reinterpret[float16.Num](in, len), out) case arrow.FLOAT32: castNumberToNumberUnsafeImpl(otype, reinterpret[float32](in, len), out) case arrow.FLOAT64: diff --git a/arrow/compute/internal/kernels/helpers.go b/arrow/compute/internal/kernels/helpers.go index 4a9ead1..ef5f0bb 100644 --- a/arrow/compute/internal/kernels/helpers.go +++ b/arrow/compute/internal/kernels/helpers.go @@ -695,7 +695,11 @@ func castNumberToNumberUnsafe(in, out *exec.ArraySpan) { inputOffset := in.Type.(arrow.FixedWidthDataType).Bytes() * int(in.Offset) outputOffset := out.Type.(arrow.FixedWidthDataType).Bytes() * int(out.Offset) - castNumericUnsafe(in.Type.ID(), out.Type.ID(), in.Buffers[1].Buf[inputOffset:], out.Buffers[1].Buf[outputOffset:], int(in.Len)) + if in.Type.ID() == arrow.FLOAT16 || out.Type.ID() == arrow.FLOAT16 { + castNumericGo(in.Type.ID(), out.Type.ID(), in.Buffers[1].Buf[inputOffset:], out.Buffers[1].Buf[outputOffset:], int(in.Len)) + } else { + castNumericUnsafe(in.Type.ID(), out.Type.ID(), in.Buffers[1].Buf[inputOffset:], out.Buffers[1].Buf[outputOffset:], int(in.Len)) + } } func MaxDecimalDigitsForInt(id arrow.Type) (int32, error) { diff --git a/arrow/compute/internal/kernels/numeric_cast.go b/arrow/compute/internal/kernels/numeric_cast.go index 1e76709..7681b02 100644 --- a/arrow/compute/internal/kernels/numeric_cast.go +++ b/arrow/compute/internal/kernels/numeric_cast.go @@ -28,6 +28,7 @@ import ( "github.com/apache/arrow-go/v18/arrow/compute/exec" "github.com/apache/arrow-go/v18/arrow/decimal128" "github.com/apache/arrow-go/v18/arrow/decimal256" + "github.com/apache/arrow-go/v18/arrow/float16" "github.com/apache/arrow-go/v18/arrow/internal/debug" "github.com/apache/arrow-go/v18/internal/bitutils" "golang.org/x/exp/constraints" @@ -506,6 +507,27 @@ func CastFloat64ToDecimal(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.E return executor(ctx, batch, out) } +func CastDecimalToFloat16(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error { + var ( + executor exec.ArrayKernelExec + ) + + switch dt := batch.Values[0].Array.Type.(type) { + case *arrow.Decimal128Type: + scale := dt.Scale + executor = ScalarUnaryNotNull(func(_ *exec.KernelCtx, v decimal128.Num, err *error) float16.Num { + return float16.New(v.ToFloat32(scale)) + }) + case *arrow.Decimal256Type: + scale := dt.Scale + executor = ScalarUnaryNotNull(func(_ *exec.KernelCtx, v decimal256.Num, err *error) float16.Num { + return float16.New(v.ToFloat32(scale)) + }) + } + + return executor(ctx, batch, out) +} + func CastDecimalToFloating[OutT constraints.Float](ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error { var ( executor exec.ArrayKernelExec @@ -543,13 +565,49 @@ func boolToNum[T numeric](_ *exec.KernelCtx, in []byte, out []T) error { return nil } -func checkFloatTrunc[InT constraints.Float, OutT arrow.IntType | arrow.UintType](in, out *exec.ArraySpan) error { - wasTrunc := func(out OutT, in InT) bool { - return InT(out) != in +func boolToFloat16(_ *exec.KernelCtx, in []byte, out []float16.Num) error { + var ( + zero float16.Num + one = float16.New(1) + ) + + for i := range out { + if bitutil.BitIsSet(in, i) { + out[i] = one + } else { + out[i] = zero + } } - wasTruncMaybeNull := func(out OutT, in InT, isValid bool) bool { - return isValid && (InT(out) != in) + return nil +} + +func wasTrunc[InT constraints.Float | float16.Num, OutT arrow.IntType | arrow.UintType](out OutT, in InT) bool { + switch v := any(in).(type) { + case float16.Num: + return float16.New(float32(out)) != v + case float32: + return float32(out) != v + case float64: + return float64(out) != v + default: + return false + } +} + +func wasTruncMaybeNull[InT constraints.Float | float16.Num, OutT arrow.IntType | arrow.UintType](out OutT, in InT, isValid bool) bool { + switch v := any(in).(type) { + case float16.Num: + return isValid && (float16.New(float32(out)) != v) + case float32: + return isValid && (float32(out) != v) + case float64: + return isValid && (float64(out) != v) + default: + return false } +} + +func checkFloatTrunc[InT constraints.Float | float16.Num, OutT arrow.IntType | arrow.UintType](in, out *exec.ArraySpan) error { getError := func(val InT) error { return fmt.Errorf("%w: float value %f was truncated converting to %s", arrow.ErrInvalid, val, out.Type) @@ -598,7 +656,7 @@ func checkFloatTrunc[InT constraints.Float, OutT arrow.IntType | arrow.UintType] return nil } -func checkFloatToIntTruncImpl[T constraints.Float](in, out *exec.ArraySpan) error { +func checkFloatToIntTruncImpl[T constraints.Float | float16.Num](in, out *exec.ArraySpan) error { switch out.Type.ID() { case arrow.INT8: return checkFloatTrunc[T, int8](in, out) @@ -623,6 +681,8 @@ func checkFloatToIntTruncImpl[T constraints.Float](in, out *exec.ArraySpan) erro func checkFloatToIntTrunc(in, out *exec.ArraySpan) error { switch in.Type.ID() { + case arrow.FLOAT16: + return checkFloatToIntTruncImpl[float16.Num](in, out) case arrow.FLOAT32: return checkFloatToIntTruncImpl[float32](in, out) case arrow.FLOAT64: @@ -729,6 +789,26 @@ func getParseStringExec[OffsetT int32 | int64](out arrow.Type) exec.ArrayKernelE panic("invalid type for getParseStringExec") } +func addFloat16Casts(outTy arrow.DataType, kernels []exec.ScalarKernel) []exec.ScalarKernel { + kernels = append(kernels, GetCommonCastKernels(outTy.ID(), exec.NewOutputType(outTy))...) + + kernels = append(kernels, exec.NewScalarKernel( + []exec.InputType{exec.NewExactInput(arrow.FixedWidthTypes.Boolean)}, + exec.NewOutputType(outTy), ScalarUnaryBoolArg(boolToFloat16), nil)) + + for _, inTy := range []arrow.DataType{arrow.BinaryTypes.Binary, arrow.BinaryTypes.String} { + kernels = append(kernels, exec.NewScalarKernel( + []exec.InputType{exec.NewExactInput(inTy)}, exec.NewOutputType(outTy), + getParseStringExec[int32](outTy.ID()), nil)) + } + for _, inTy := range []arrow.DataType{arrow.BinaryTypes.LargeBinary, arrow.BinaryTypes.LargeString} { + kernels = append(kernels, exec.NewScalarKernel( + []exec.InputType{exec.NewExactInput(inTy)}, exec.NewOutputType(outTy), + getParseStringExec[int64](outTy.ID()), nil)) + } + return kernels +} + func addCommonNumberCasts[T numeric](outTy arrow.DataType, kernels []exec.ScalarKernel) []exec.ScalarKernel { kernels = append(kernels, GetCommonCastKernels(outTy.ID(), exec.NewOutputType(outTy))...) @@ -759,7 +839,7 @@ func GetCastToInteger[T arrow.IntType | arrow.UintType](outType arrow.DataType) CastIntToInt, nil)) } - for _, inTy := range floatingTypes { + for _, inTy := range append(floatingTypes, arrow.FixedWidthTypes.Float16) { kernels = append(kernels, exec.NewScalarKernel( []exec.InputType{exec.NewExactInput(inTy)}, output, CastFloatingToInteger, nil)) @@ -775,7 +855,7 @@ func GetCastToInteger[T arrow.IntType | arrow.UintType](outType arrow.DataType) return kernels } -func GetCastToFloating[T constraints.Float](outType arrow.DataType) []exec.ScalarKernel { +func GetCastToFloating[T constraints.Float | float16.Num](outType arrow.DataType) []exec.ScalarKernel { kernels := make([]exec.ScalarKernel, 0) output := exec.NewOutputType(outType) @@ -785,19 +865,40 @@ func GetCastToFloating[T constraints.Float](outType arrow.DataType) []exec.Scala CastIntegerToFloating, nil)) } - for _, inTy := range floatingTypes { + for _, inTy := range append(floatingTypes, arrow.FixedWidthTypes.Float16) { kernels = append(kernels, exec.NewScalarKernel( []exec.InputType{exec.NewExactInput(inTy)}, output, CastFloatingToFloating, nil)) } - kernels = addCommonNumberCasts[T](outType, kernels) - kernels = append(kernels, exec.NewScalarKernel( - []exec.InputType{exec.NewIDInput(arrow.DECIMAL128)}, output, - CastDecimalToFloating[T], nil)) - kernels = append(kernels, exec.NewScalarKernel( - []exec.InputType{exec.NewIDInput(arrow.DECIMAL256)}, output, - CastDecimalToFloating[T], nil)) + var z T + switch any(z).(type) { + case float16.Num: + kernels = addFloat16Casts(outType, kernels) + kernels = append(kernels, exec.NewScalarKernel( + []exec.InputType{exec.NewIDInput(arrow.DECIMAL128)}, output, + CastDecimalToFloat16, nil)) + kernels = append(kernels, exec.NewScalarKernel( + []exec.InputType{exec.NewIDInput(arrow.DECIMAL256)}, output, + CastDecimalToFloat16, nil)) + case float32: + kernels = addCommonNumberCasts[float32](outType, kernels) + kernels = append(kernels, exec.NewScalarKernel( + []exec.InputType{exec.NewIDInput(arrow.DECIMAL128)}, output, + CastDecimalToFloating[float32], nil)) + kernels = append(kernels, exec.NewScalarKernel( + []exec.InputType{exec.NewIDInput(arrow.DECIMAL256)}, output, + CastDecimalToFloating[float32], nil)) + case float64: + kernels = addCommonNumberCasts[float64](outType, kernels) + kernels = append(kernels, exec.NewScalarKernel( + []exec.InputType{exec.NewIDInput(arrow.DECIMAL128)}, output, + CastDecimalToFloating[float64], nil)) + kernels = append(kernels, exec.NewScalarKernel( + []exec.InputType{exec.NewIDInput(arrow.DECIMAL256)}, output, + CastDecimalToFloating[float64], nil)) + } + return kernels } diff --git a/arrow/float16/float16.go b/arrow/float16/float16.go index 0aa4df8..f6c276b 100644 --- a/arrow/float16/float16.go +++ b/arrow/float16/float16.go @@ -18,6 +18,7 @@ package float16 import ( "encoding/binary" + "fmt" "math" "strconv" ) @@ -58,6 +59,10 @@ func New(f float32) Num { return Num{bits: (sn << 15) | uint16(res<<10) | fc} } +func (f Num) Format(s fmt.State, verb rune) { + fmt.Fprintf(s, fmt.FormatString(s, verb), f.Float32()) +} + func (f Num) Float32() float32 { sn := uint32((f.bits >> 15) & 0x1) exp := (f.bits >> 10) & 0x1f @@ -179,7 +184,7 @@ func (n Num) IsInf() bool { return (n.bits & 0x7c00) == 0x7c00 } func (n Num) IsZero() bool { return (n.bits & 0x7fff) == 0 } -func (f Num) Uint16() uint16 { return f.bits } +func (f Num) Uint16() uint16 { return uint16(f.bits) } func (f Num) String() string { return strconv.FormatFloat(float64(f.Float32()), 'g', -1, 32) } func Inf() Num { return Num{bits: 0x7c00} } diff --git a/arrow/float16/float16_test.go b/arrow/float16/float16_test.go index cfde440..9857eda 100644 --- a/arrow/float16/float16_test.go +++ b/arrow/float16/float16_test.go @@ -38,7 +38,7 @@ func TestFloat16(t *testing.T) { f := k.Float32() assert.Equal(t, v, f, "float32 values should be the same") i := New(v) - assert.Equal(t, k.bits, i.bits, "float16 values should be the same") + assert.Equal(t, k, i, "float16 values should be the same") assert.Equal(t, k.Uint16(), i.Uint16(), "float16 values should be the same") assert.Equal(t, k.String(), fmt.Sprintf("%v", v), "string representation differ") }