This is an automated email from the ASF dual-hosted git repository.
kou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 72d20ad719 GH-20213: [C++] Implement cast to/from halffloat (#40067)
72d20ad719 is described below
commit 72d20ad719021c5513620e23a0a65fb724f0e299
Author: Clif Houck <[email protected]>
AuthorDate: Thu Apr 4 08:59:30 2024 -0500
GH-20213: [C++] Implement cast to/from halffloat (#40067)
### Rationale for this change
### What changes are included in this PR?
This PR implements casting to and from float16 types using the vendored
float16 library included in Arrow at `cpp/src/arrow/util/float16.*`.
### Are these changes tested?
Unit tests are included in this PR.
### Are there any user-facing changes?
Yes: casts to and from float16 now work.
* Closes: #20213
### TODO
- [x] Add casts to/from float64.
- [x] String <-> float16 casts.
- [x] Integer <-> float16 casts.
- [x] Tests.
- [x] Update
https://github.com/apache/arrow/blob/main/docs/source/status.rst about half
float.
- [x] Rebase.
- [x] Run clang format over this PR.
* GitHub Issue: #20213
Authored-by: Clif Houck <[email protected]>
Signed-off-by: Sutou Kouhei <[email protected]>
---
c_glib/test/test-half-float-scalar.rb | 2 +-
cpp/src/arrow/compare.cc | 30 ++++++
.../arrow/compute/kernels/scalar_cast_internal.cc | 70 ++++++++++++++
.../arrow/compute/kernels/scalar_cast_numeric.cc | 103 +++++++++++++++++----
.../arrow/compute/kernels/scalar_cast_string.cc | 4 +
cpp/src/arrow/compute/kernels/scalar_cast_test.cc | 25 +++--
cpp/src/arrow/ipc/json_simple.cc | 32 ++++++-
cpp/src/arrow/ipc/json_simple_test.cc | 35 ++++++-
cpp/src/arrow/record_batch_test.cc | 3 +
cpp/src/arrow/type_traits.h | 1 +
cpp/src/arrow/util/formatting.cc | 11 +++
cpp/src/arrow/util/formatting.h | 7 ++
cpp/src/arrow/util/value_parsing.cc | 14 +++
cpp/src/arrow/util/value_parsing.h | 17 ++++
docs/source/status.rst | 11 +--
15 files changed, 325 insertions(+), 40 deletions(-)
diff --git a/c_glib/test/test-half-float-scalar.rb
b/c_glib/test/test-half-float-scalar.rb
index ac41f91ece..3073d84d79 100644
--- a/c_glib/test/test-half-float-scalar.rb
+++ b/c_glib/test/test-half-float-scalar.rb
@@ -41,7 +41,7 @@ class TestHalfFloatScalar < Test::Unit::TestCase
end
def test_to_s
- assert_equal("[\n #{@half_float}\n]", @scalar.to_s)
+ assert_equal("1.0009765625", @scalar.to_s)
end
def test_value
diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc
index bb632e2eb9..e983b47e39 100644
--- a/cpp/src/arrow/compare.cc
+++ b/cpp/src/arrow/compare.cc
@@ -44,6 +44,7 @@
#include "arrow/util/bitmap_ops.h"
#include "arrow/util/bitmap_reader.h"
#include "arrow/util/checked_cast.h"
+#include "arrow/util/float16.h"
#include "arrow/util/key_value_metadata.h"
#include "arrow/util/logging.h"
#include "arrow/util/macros.h"
@@ -59,6 +60,7 @@ using internal::BitmapReader;
using internal::BitmapUInt64Reader;
using internal::checked_cast;
using internal::OptionalBitmapEquals;
+using util::Float16;
// ----------------------------------------------------------------------
// Public method implementations
@@ -95,6 +97,30 @@ struct FloatingEquality {
const T epsilon;
};
+// Equality functor for half floats, which arrive as raw uint16_t bit patterns
+// and are decoded through Float16 before being compared.
+//
+// `Flags` supplies three compile-time switches: `signed_zeros_equal`,
+// `nans_equal` and `approximate`.
+template <typename Flags>
+struct FloatingEquality<uint16_t, Flags> {
+  explicit FloatingEquality(const EqualOptions& options)
+      : epsilon(static_cast<float>(options.atol())) {}
+
+  // Returns true when the half floats encoded by `x` and `y` are considered
+  // equal under `Flags`.
+  bool operator()(uint16_t x, uint16_t y) const {
+    const Float16 f_x = Float16::FromBits(x);
+    const Float16 f_y = Float16::FromBits(y);
+    const float v_x = f_x.ToFloat();
+    const float v_y = f_y.ToFloat();
+    // Compare decoded values rather than bit patterns: +0.0 and -0.0 differ in
+    // their bits yet are equal as values (so the sign check below must run),
+    // while a NaN must not compare equal to itself on this path.
+    if (v_x == v_y) {
+      return Flags::signed_zeros_equal || (f_x.signbit() == f_y.signbit());
+    }
+    if (Flags::nans_equal && f_x.is_nan() && f_y.is_nan()) {
+      return true;
+    }
+    if (Flags::approximate && (fabs(v_x - v_y) <= epsilon)) {
+      return true;
+    }
+    return false;
+  }
+
+  // Absolute tolerance taken from EqualOptions::atol(), narrowed to float.
+  const float epsilon;
+};
+
template <typename T, typename Visitor>
struct FloatingEqualityDispatcher {
const EqualOptions& options;
@@ -259,6 +285,8 @@ class RangeDataEqualsImpl {
Status Visit(const DoubleType& type) { return CompareFloating(type); }
+ Status Visit(const HalfFloatType& type) { return CompareFloating(type); }
+
// Also matches StringType
Status Visit(const BinaryType& type) { return CompareBinary(type); }
@@ -863,6 +891,8 @@ class ScalarEqualsVisitor {
Status Visit(const DoubleScalar& left) { return CompareFloating(left); }
+ Status Visit(const HalfFloatScalar& left) { return CompareFloating(left); }
+
template <typename T>
enable_if_t<std::is_base_of<BaseBinaryScalar, T>::value, Status> Visit(const
T& left) {
const auto& right = checked_cast<const BaseBinaryScalar&>(right_);
diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc
b/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc
index 8cf5a04add..d8c4088759 100644
--- a/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc
@@ -19,10 +19,13 @@
#include "arrow/compute/cast_internal.h"
#include "arrow/compute/kernels/common_internal.h"
#include "arrow/extension_type.h"
+#include "arrow/type_traits.h"
#include "arrow/util/checked_cast.h"
+#include "arrow/util/float16.h"
namespace arrow {
+using arrow::util::Float16;
using internal::checked_cast;
using internal::PrimitiveScalarBase;
@@ -47,6 +50,42 @@ struct CastPrimitive {
}
};
+// Cast kernel: physical floating-point input -> half float.
+// Every input value is passed to the Float16 constructor and the resulting raw
+// 16-bit pattern is stored in the output buffer.
+template <typename InType>
+struct CastPrimitive<HalfFloatType, InType, enable_if_physical_floating_point<InType>> {
+  static void Exec(const ArraySpan& arr, ArraySpan* out) {
+    using SourceT = typename InType::c_type;
+    const SourceT* src = arr.GetValues<SourceT>(1);
+    uint16_t* dst = out->GetValues<uint16_t>(1);
+    const int64_t count = arr.length;
+    for (int64_t pos = 0; pos < count; ++pos) {
+      dst[pos] = Float16(src[pos]).bits();
+    }
+  }
+};
+
+// Cast kernel: half float -> float.
+// Each stored 16-bit pattern is decoded with Float16::FromBits() and widened
+// with ToFloat().
+template <>
+struct CastPrimitive<FloatType, HalfFloatType, enable_if_t<true>> {
+  static void Exec(const ArraySpan& arr, ArraySpan* out) {
+    const uint16_t* src = arr.GetValues<uint16_t>(1);
+    float* dst = out->GetValues<float>(1);
+    const int64_t count = arr.length;
+    for (int64_t pos = 0; pos < count; ++pos) {
+      dst[pos] = Float16::FromBits(src[pos]).ToFloat();
+    }
+  }
+};
+
+// Cast kernel: half float -> double.
+// Each stored 16-bit pattern is decoded with Float16::FromBits() and widened
+// with ToDouble().
+template <>
+struct CastPrimitive<DoubleType, HalfFloatType, enable_if_t<true>> {
+  static void Exec(const ArraySpan& arr, ArraySpan* out) {
+    const uint16_t* src = arr.GetValues<uint16_t>(1);
+    double* dst = out->GetValues<double>(1);
+    const int64_t count = arr.length;
+    for (int64_t pos = 0; pos < count; ++pos) {
+      dst[pos] = Float16::FromBits(src[pos]).ToDouble();
+    }
+  }
+};
+
template <typename OutType, typename InType>
struct CastPrimitive<OutType, InType, enable_if_t<std::is_same<OutType,
InType>::value>> {
// memcpy output
@@ -56,6 +95,33 @@ struct CastPrimitive<OutType, InType,
enable_if_t<std::is_same<OutType, InType>:
}
};
+// Cast kernel: integer -> half float.
+// The integer is first converted to float, then handed to the Float16
+// constructor; the raw 16-bit pattern is what gets stored.
+template <typename InType>
+struct CastPrimitive<HalfFloatType, InType, enable_if_integer<InType>> {
+  static void Exec(const ArraySpan& arr, ArraySpan* out) {
+    using SourceT = typename InType::c_type;
+    const SourceT* src = arr.GetValues<SourceT>(1);
+    uint16_t* dst = out->GetValues<uint16_t>(1);
+    const int64_t count = arr.length;
+    for (int64_t pos = 0; pos < count; ++pos) {
+      const float as_float = static_cast<float>(src[pos]);
+      dst[pos] = Float16(as_float).bits();
+    }
+  }
+};
+
+// Cast kernel: half float -> integer.
+// Each stored 16-bit pattern is decoded to float and then converted to the
+// output integer type with a plain static_cast.
+template <typename OutType>
+struct CastPrimitive<OutType, HalfFloatType, enable_if_integer<OutType>> {
+  static void Exec(const ArraySpan& arr, ArraySpan* out) {
+    using TargetT = typename OutType::c_type;
+    const uint16_t* src = arr.GetValues<uint16_t>(1);
+    TargetT* dst = out->GetValues<TargetT>(1);
+    const int64_t count = arr.length;
+    for (int64_t pos = 0; pos < count; ++pos) {
+      const float widened = Float16::FromBits(src[pos]).ToFloat();
+      dst[pos] = static_cast<TargetT>(widened);
+    }
+  }
+};
+
template <typename InType>
void CastNumberImpl(Type::type out_type, const ArraySpan& input, ArraySpan*
out) {
switch (out_type) {
@@ -79,6 +145,8 @@ void CastNumberImpl(Type::type out_type, const ArraySpan&
input, ArraySpan* out)
return CastPrimitive<FloatType, InType>::Exec(input, out);
case Type::DOUBLE:
return CastPrimitive<DoubleType, InType>::Exec(input, out);
+ case Type::HALF_FLOAT:
+ return CastPrimitive<HalfFloatType, InType>::Exec(input, out);
default:
break;
}
@@ -109,6 +177,8 @@ void CastNumberToNumberUnsafe(Type::type in_type,
Type::type out_type,
return CastNumberImpl<FloatType>(out_type, input, out);
case Type::DOUBLE:
return CastNumberImpl<DoubleType>(out_type, input, out);
+ case Type::HALF_FLOAT:
+ return CastNumberImpl<HalfFloatType>(out_type, input, out);
default:
DCHECK(false);
break;
diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc
b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc
index b054e57f04..3df86e7d69 100644
--- a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc
@@ -23,6 +23,7 @@
#include "arrow/compute/kernels/util_internal.h"
#include "arrow/scalar.h"
#include "arrow/util/bit_block_counter.h"
+#include "arrow/util/float16.h"
#include "arrow/util/int_util.h"
#include "arrow/util/value_parsing.h"
@@ -34,6 +35,7 @@ using internal::IntegersCanFit;
using internal::OptionalBitBlockCounter;
using internal::ParseValue;
using internal::PrimitiveScalarBase;
+using util::Float16;
namespace compute {
namespace internal {
@@ -56,18 +58,37 @@ Status CastFloatingToFloating(KernelContext*, const
ExecSpan& batch, ExecResult*
// ----------------------------------------------------------------------
// Implement fast safe floating point to integer cast
+//
+// Detects a lossy floating-point -> integer conversion by casting the integer
+// result back to the input C type and comparing it with the original value.
+template <typename InType, typename OutType, typename InT = typename InType::c_type,
+          typename OutT = typename OutType::c_type>
+struct WasTruncated {
+  // True when `out_val` does not convert back to exactly `in_val`.
+  static bool Check(OutT out_val, InT in_val) {
+    const InT converted_back = static_cast<InT>(out_val);
+    return converted_back != in_val;
+  }
+
+  // Same as Check(), except that a slot with `is_valid == false` never counts
+  // as truncated.
+  static bool CheckMaybeNull(OutT out_val, InT in_val, bool is_valid) {
+    return is_valid && Check(out_val, in_val);
+  }
+};
+
+// Half float to int: the input is a raw 16-bit pattern, so it is decoded to
+// float and compared against the integer result converted to float.
+template <typename OutType>
+struct WasTruncated<HalfFloatType, OutType> {
+  using OutT = typename OutType::c_type;
+
+  // True when `out_val`, as a float, differs from the half float in `in_val`.
+  static bool Check(OutT out_val, uint16_t in_val) {
+    const float original = Float16::FromBits(in_val).ToFloat();
+    return static_cast<float>(out_val) != original;
+  }
+
+  // Same as Check(), except that a slot with `is_valid == false` never counts
+  // as truncated.
+  static bool CheckMaybeNull(OutT out_val, uint16_t in_val, bool is_valid) {
+    return is_valid && Check(out_val, in_val);
+  }
+};
// InType is a floating point type we are planning to cast to integer
template <typename InType, typename OutType, typename InT = typename
InType::c_type,
typename OutT = typename OutType::c_type>
ARROW_DISABLE_UBSAN("float-cast-overflow")
Status CheckFloatTruncation(const ArraySpan& input, const ArraySpan& output) {
- auto WasTruncated = [&](OutT out_val, InT in_val) -> bool {
- return static_cast<InT>(out_val) != in_val;
- };
- auto WasTruncatedMaybeNull = [&](OutT out_val, InT in_val, bool is_valid) ->
bool {
- return is_valid && static_cast<InT>(out_val) != in_val;
- };
auto GetErrorMessage = [&](InT val) {
return Status::Invalid("Float value ", val, " was truncated converting to
",
*output.type);
@@ -86,26 +107,28 @@ Status CheckFloatTruncation(const ArraySpan& input, const
ArraySpan& output) {
if (block.popcount == block.length) {
// Fast path: branchless
for (int64_t i = 0; i < block.length; ++i) {
- block_out_of_bounds |= WasTruncated(out_data[i], in_data[i]);
+ block_out_of_bounds |=
+ WasTruncated<InType, OutType>::Check(out_data[i], in_data[i]);
}
} else if (block.popcount > 0) {
// Indices have nulls, must only boundscheck non-null values
for (int64_t i = 0; i < block.length; ++i) {
- block_out_of_bounds |= WasTruncatedMaybeNull(
+ block_out_of_bounds |= WasTruncated<InType, OutType>::CheckMaybeNull(
out_data[i], in_data[i], bit_util::GetBit(bitmap, offset_position
+ i));
}
}
if (ARROW_PREDICT_FALSE(block_out_of_bounds)) {
if (input.GetNullCount() > 0) {
for (int64_t i = 0; i < block.length; ++i) {
- if (WasTruncatedMaybeNull(out_data[i], in_data[i],
- bit_util::GetBit(bitmap, offset_position +
i))) {
+ if (WasTruncated<InType, OutType>::CheckMaybeNull(
+ out_data[i], in_data[i],
+ bit_util::GetBit(bitmap, offset_position + i))) {
return GetErrorMessage(in_data[i]);
}
}
} else {
for (int64_t i = 0; i < block.length; ++i) {
- if (WasTruncated(out_data[i], in_data[i])) {
+ if (WasTruncated<InType, OutType>::Check(out_data[i], in_data[i])) {
return GetErrorMessage(in_data[i]);
}
}
@@ -151,6 +174,9 @@ Status CheckFloatToIntTruncation(const ExecValue& input,
const ExecResult& outpu
return CheckFloatToIntTruncationImpl<FloatType>(input.array,
*output.array_span());
case Type::DOUBLE:
return CheckFloatToIntTruncationImpl<DoubleType>(input.array,
*output.array_span());
+ case Type::HALF_FLOAT:
+ return CheckFloatToIntTruncationImpl<HalfFloatType>(input.array,
+
*output.array_span());
default:
break;
}
@@ -293,6 +319,15 @@ struct CastFunctor<
}
};
+// String -> half float cast: delegates to applicator::ScalarUnaryNotNull with
+// ParseString<HalfFloatType> as the per-value operation.
+template <>
+struct CastFunctor<HalfFloatType, StringType, enable_if_t<true>> {
+  static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
+    using Kernel = applicator::ScalarUnaryNotNull<HalfFloatType, StringType,
+                                                  ParseString<HalfFloatType>>;
+    return Kernel::Exec(ctx, batch, out);
+  }
+};
+
// ----------------------------------------------------------------------
// Decimal to integer
@@ -689,6 +724,10 @@ std::shared_ptr<CastFunction> GetCastToInteger(std::string
name) {
DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty,
CastFloatingToInteger));
}
+ // Cast from half-float
+ DCHECK_OK(func->AddKernel(Type::HALF_FLOAT, {InputType(Type::HALF_FLOAT)},
out_ty,
+ CastFloatingToInteger));
+
// From other numbers to integer
AddCommonNumberCasts<OutType>(out_ty, func.get());
@@ -715,6 +754,10 @@ std::shared_ptr<CastFunction>
GetCastToFloating(std::string name) {
DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty,
CastFloatingToFloating));
}
+ // From half-float to float/double
+ DCHECK_OK(func->AddKernel(Type::HALF_FLOAT, {InputType(Type::HALF_FLOAT)},
out_ty,
+ CastFloatingToFloating));
+
// From other numbers to floating point
AddCommonNumberCasts<OutType>(out_ty, func.get());
@@ -723,6 +766,7 @@ std::shared_ptr<CastFunction> GetCastToFloating(std::string
name) {
CastFunctor<OutType, Decimal128Type>::Exec));
DCHECK_OK(func->AddKernel(Type::DECIMAL256, {InputType(Type::DECIMAL256)},
out_ty,
CastFunctor<OutType, Decimal256Type>::Exec));
+
return func;
}
@@ -795,6 +839,32 @@ std::shared_ptr<CastFunction> GetCastToDecimal256() {
return func;
}
+// Builds the cast function whose output type is half float (float16).
+// On top of the common casts it registers kernels from every integer type,
+// from every base-binary (string/binary) type, and from float32/float64.
+std::shared_ptr<CastFunction> GetCastToHalfFloat() {
+  // Registered as "cast_half_float", the name used before this function was
+  // factored out, rather than a generic placeholder.
+  auto func = std::make_shared<CastFunction>("cast_half_float", Type::HALF_FLOAT);
+  const std::shared_ptr<DataType> out_ty = float16();
+  AddCommonCasts(Type::HALF_FLOAT, out_ty, func.get());
+
+  // Casts from integer to half float
+  for (const std::shared_ptr<DataType>& in_ty : IntTypes()) {
+    DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, CastIntegerToFloating));
+  }
+
+  // Casts from binary/string types to half float
+  for (const std::shared_ptr<DataType>& in_ty : BaseBinaryTypes()) {
+    auto exec = GenerateVarBinaryBase<CastFunctor, HalfFloatType>(*in_ty);
+    DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, exec));
+  }
+
+  // Casts from float32/float64 to half float
+  DCHECK_OK(func->AddKernel(Type::FLOAT, {InputType(Type::FLOAT)}, out_ty,
+                            CastFloatingToFloating));
+  DCHECK_OK(func->AddKernel(Type::DOUBLE, {InputType(Type::DOUBLE)}, out_ty,
+                            CastFloatingToFloating));
+  return func;
+}
+
} // namespace
std::vector<std::shared_ptr<CastFunction>> GetNumericCasts() {
@@ -830,13 +900,14 @@ std::vector<std::shared_ptr<CastFunction>>
GetNumericCasts() {
functions.push_back(GetCastToInteger<UInt64Type>("cast_uint64"));
// HalfFloat is a bit brain-damaged for now
- auto cast_half_float =
- std::make_shared<CastFunction>("cast_half_float", Type::HALF_FLOAT);
- AddCommonCasts(Type::HALF_FLOAT, float16(), cast_half_float.get());
+ auto cast_half_float = GetCastToHalfFloat();
functions.push_back(cast_half_float);
- functions.push_back(GetCastToFloating<FloatType>("cast_float"));
- functions.push_back(GetCastToFloating<DoubleType>("cast_double"));
+ auto cast_float = GetCastToFloating<FloatType>("cast_float");
+ functions.push_back(cast_float);
+
+ auto cast_double = GetCastToFloating<DoubleType>("cast_double");
+ functions.push_back(cast_double);
functions.push_back(GetCastToDecimal128());
functions.push_back(GetCastToDecimal256());
diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_string.cc
b/cpp/src/arrow/compute/kernels/scalar_cast_string.cc
index a6576e4e4c..3a8352a9b8 100644
--- a/cpp/src/arrow/compute/kernels/scalar_cast_string.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_cast_string.cc
@@ -437,6 +437,10 @@ void AddNumberToStringCasts(CastFunction* func) {
GenerateNumeric<NumericToStringCastFunctor,
OutType>(*in_ty),
NullHandling::COMPUTED_NO_PREALLOCATE));
}
+
+ DCHECK_OK(func->AddKernel(Type::HALF_FLOAT, {float16()}, out_ty,
+ NumericToStringCastFunctor<OutType,
HalfFloatType>::Exec,
+ NullHandling::COMPUTED_NO_PREALLOCATE));
}
template <typename OutType>
diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc
b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc
index a8acf68f66..af62b4da2c 100644
--- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc
@@ -389,7 +389,7 @@ TEST(Cast, ToIntDowncastUnsafe) {
}
TEST(Cast, FloatingToInt) {
- for (auto from : {float32(), float64()}) {
+ for (auto from : {float16(), float32(), float64()}) {
for (auto to : {int32(), int64()}) {
// float to int no truncation
CheckCast(ArrayFromJSON(from, "[1.0, null, 0.0, -1.0, 5.0]"),
@@ -407,6 +407,15 @@ TEST(Cast, FloatingToInt) {
}
}
+// Casting between any two of float16/float32/float64 must preserve these
+// values.
+TEST(Cast, FloatingToFloating) {
+  const std::string values = "[1.0, 0.0, -1.0, 5.0]";
+  for (const auto& from : {float16(), float32(), float64()}) {
+    for (const auto& to : {float16(), float32(), float64()}) {
+      const auto input = ArrayFromJSON(from, values);
+      const auto expected = ArrayFromJSON(to, values);
+      CheckCast(input, expected);
+    }
+  }
+}
+
TEST(Cast, IntToFloating) {
for (auto from : {uint32(), int32()}) {
std::string two_24 = "[16777216, 16777217]";
@@ -2220,14 +2229,12 @@ TEST(Cast, IntToString) {
}
TEST(Cast, FloatingToString) {
- for (auto string_type : {utf8(), large_utf8()}) {
- CheckCast(
- ArrayFromJSON(float32(), "[0.0, -0.0, 1.5, -Inf, Inf, NaN, null]"),
- ArrayFromJSON(string_type, R"(["0", "-0", "1.5", "-inf", "inf", "nan",
null])"));
-
- CheckCast(
- ArrayFromJSON(float64(), "[0.0, -0.0, 1.5, -Inf, Inf, NaN, null]"),
- ArrayFromJSON(string_type, R"(["0", "-0", "1.5", "-inf", "inf", "nan",
null])"));
+ for (auto float_type : {float16(), float32(), float64()}) {
+ for (auto string_type : {utf8(), large_utf8()}) {
+ CheckCast(ArrayFromJSON(float_type, "[0.0, -0.0, 1.5, -Inf, Inf, NaN,
null]"),
+ ArrayFromJSON(string_type,
+ R"(["0", "-0", "1.5", "-inf", "inf", "nan",
null])"));
+ }
}
}
diff --git a/cpp/src/arrow/ipc/json_simple.cc b/cpp/src/arrow/ipc/json_simple.cc
index ceeabe0167..9fd449831c 100644
--- a/cpp/src/arrow/ipc/json_simple.cc
+++ b/cpp/src/arrow/ipc/json_simple.cc
@@ -36,6 +36,7 @@
#include "arrow/type_traits.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/decimal.h"
+#include "arrow/util/float16.h"
#include "arrow/util/logging.h"
#include "arrow/util/value_parsing.h"
@@ -52,6 +53,7 @@ namespace rj = arrow::rapidjson;
namespace arrow {
using internal::ParseValue;
+using util::Float16;
namespace ipc {
namespace internal {
@@ -232,9 +234,9 @@ enable_if_physical_signed_integer<T, Status>
ConvertNumber(const rj::Value& json
// Convert single unsigned integer value
template <typename T>
-enable_if_physical_unsigned_integer<T, Status> ConvertNumber(const rj::Value&
json_obj,
- const DataType&
type,
- typename
T::c_type* out) {
+enable_if_unsigned_integer<T, Status> ConvertNumber(const rj::Value& json_obj,
+ const DataType& type,
+ typename T::c_type* out) {
if (json_obj.IsUint64()) {
uint64_t v64 = json_obj.GetUint64();
*out = static_cast<typename T::c_type>(v64);
@@ -249,6 +251,30 @@ enable_if_physical_unsigned_integer<T, Status>
ConvertNumber(const rj::Value& js
}
}
+// Convert a single float16/HalfFloatType value.
+//
+// Any JSON number is accepted: it is read as a double, converted with the
+// Float16 constructor, and its raw 16-bit pattern is stored in `out`.
+// For anything else `out` is set to 0 and a JSON type error is returned.
+template <typename T>
+enable_if_half_float<T, Status> ConvertNumber(const rj::Value& json_obj,
+                                              const DataType& type, uint16_t* out) {
+  if (json_obj.IsNumber()) {
+    // GetDouble() covers doubles and every integer width, so no per-kind
+    // branches are needed.
+    *out = Float16(json_obj.GetDouble()).bits();
+    return Status::OK();
+  }
+  *out = static_cast<uint16_t>(0);
+  // The expected kind is any number, not specifically an unsigned int.
+  return JSONTypeError("number", json_obj.GetType());
+}
+
// Convert single floating point value
template <typename T>
enable_if_physical_floating_point<T, Status> ConvertNumber(const rj::Value&
json_obj,
diff --git a/cpp/src/arrow/ipc/json_simple_test.cc
b/cpp/src/arrow/ipc/json_simple_test.cc
index ea3a9ae1a1..b3f7fc5b34 100644
--- a/cpp/src/arrow/ipc/json_simple_test.cc
+++ b/cpp/src/arrow/ipc/json_simple_test.cc
@@ -44,6 +44,7 @@
#include "arrow/util/bitmap_builders.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/decimal.h"
+#include "arrow/util/float16.h"
#if defined(_MSC_VER)
// "warning C4307: '+': integral constant overflow"
@@ -51,6 +52,9 @@
#endif
namespace arrow {
+
+using util::Float16;
+
namespace ipc {
namespace internal {
namespace json {
@@ -185,6 +189,21 @@ class TestIntegers : public ::testing::Test {
TYPED_TEST_SUITE_P(TestIntegers);
+// Primary template: the expected values are returned unchanged.
+template <typename DataType>
+std::vector<typename DataType::c_type> TestIntegersMutateIfNeeded(
+    std::vector<typename DataType::c_type> data) {
+  return data;
+}
+
+// TODO: This works, but is it the right way to do this?
+// HalfFloatType specialization: each expected value is replaced by the raw bit
+// pattern of the Float16 constructed from it.
+template <>
+std::vector<HalfFloatType::c_type> TestIntegersMutateIfNeeded<HalfFloatType>(
+    std::vector<HalfFloatType::c_type> data) {
+  for (HalfFloatType::c_type& value : data) {
+    value = Float16(value).bits();
+  }
+  return data;
+}
+
TYPED_TEST_P(TestIntegers, Basics) {
using T = TypeParam;
using c_type = typename T::c_type;
@@ -193,16 +212,17 @@ TYPED_TEST_P(TestIntegers, Basics) {
auto type = this->type();
AssertJSONArray<T>(type, "[]", {});
- AssertJSONArray<T>(type, "[4, 0, 5]", {4, 0, 5});
- AssertJSONArray<T>(type, "[4, null, 5]", {true, false, true}, {4, 0, 5});
+ AssertJSONArray<T>(type, "[4, 0, 5]", TestIntegersMutateIfNeeded<T>({4, 0,
5}));
+ AssertJSONArray<T>(type, "[4, null, 5]", {true, false, true},
+ TestIntegersMutateIfNeeded<T>({4, 0, 5}));
// Test limits
const auto min_val = std::numeric_limits<c_type>::min();
const auto max_val = std::numeric_limits<c_type>::max();
std::string json_string = JSONArray(0, 1, min_val);
- AssertJSONArray<T>(type, json_string, {0, 1, min_val});
+ AssertJSONArray<T>(type, json_string, TestIntegersMutateIfNeeded<T>({0, 1,
min_val}));
json_string = JSONArray(0, 1, max_val);
- AssertJSONArray<T>(type, json_string, {0, 1, max_val});
+ AssertJSONArray<T>(type, json_string, TestIntegersMutateIfNeeded<T>({0, 1,
max_val}));
}
TYPED_TEST_P(TestIntegers, Errors) {
@@ -269,7 +289,12 @@ INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt8, TestIntegers,
UInt8Type);
INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt16, TestIntegers, UInt16Type);
INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt32, TestIntegers, UInt32Type);
INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt64, TestIntegers, UInt64Type);
-INSTANTIATE_TYPED_TEST_SUITE_P(TestHalfFloat, TestIntegers, HalfFloatType);
+// FIXME: I understand that HalfFloatType is backed by a uint16_t, but does it
+// make sense to run this test over it?
+// The way ConvertNumber for HalfFloatType is currently written, it allows the
+// conversion of floating point notation to a half float, which causes failures
+// in this test, one example is asserting 0.0 cannot be parsed as a half float.
+// INSTANTIATE_TYPED_TEST_SUITE_P(TestHalfFloat, TestIntegers, HalfFloatType);
template <typename T>
class TestStrings : public ::testing::Test {
diff --git a/cpp/src/arrow/record_batch_test.cc
b/cpp/src/arrow/record_batch_test.cc
index 7e0eb1d460..95f601465b 100644
--- a/cpp/src/arrow/record_batch_test.cc
+++ b/cpp/src/arrow/record_batch_test.cc
@@ -36,11 +36,14 @@
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/random.h"
#include "arrow/type.h"
+#include "arrow/util/float16.h"
#include "arrow/util/iterator.h"
#include "arrow/util/key_value_metadata.h"
namespace arrow {
+using util::Float16;
+
class TestRecordBatch : public ::testing::Test {};
TEST_F(TestRecordBatch, Equals) {
diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h
index ed66c9367d..8caf4400fe 100644
--- a/cpp/src/arrow/type_traits.h
+++ b/cpp/src/arrow/type_traits.h
@@ -305,6 +305,7 @@ struct TypeTraits<HalfFloatType> {
using BuilderType = HalfFloatBuilder;
using ScalarType = HalfFloatScalar;
using TensorType = HalfFloatTensor;
+ using CType = uint16_t;
static constexpr int64_t bytes_required(int64_t elements) {
return elements * static_cast<int64_t>(sizeof(uint16_t));
diff --git a/cpp/src/arrow/util/formatting.cc b/cpp/src/arrow/util/formatting.cc
index c16d42ce5c..c5a7e03f85 100644
--- a/cpp/src/arrow/util/formatting.cc
+++ b/cpp/src/arrow/util/formatting.cc
@@ -18,10 +18,12 @@
#include "arrow/util/formatting.h"
#include "arrow/util/config.h"
#include "arrow/util/double_conversion.h"
+#include "arrow/util/float16.h"
#include "arrow/util/logging.h"
namespace arrow {
+using util::Float16;
using util::double_conversion::DoubleToStringConverter;
static constexpr int kMinBufferSize =
DoubleToStringConverter::kBase10MaximalLength + 1;
@@ -87,5 +89,14 @@ int FloatToStringFormatter::FormatFloat(double v, char*
out_buffer, int out_size
return builder.position();
}
+// Formats the half float whose raw bit pattern is `v` into `out_buffer`.
+// The value is widened to float and written with the converter's ToShortest();
+// the return value is the builder position after writing.
+int FloatToStringFormatter::FormatFloat(uint16_t v, char* out_buffer, int out_size) {
+  DCHECK_GE(out_size, kMinBufferSize);
+  const float widened = Float16::FromBits(v).ToFloat();
+  util::double_conversion::StringBuilder builder(out_buffer, out_size);
+  const bool ok = impl_->converter_.ToShortest(widened, &builder);
+  DCHECK(ok);
+  ARROW_UNUSED(ok);
+  return builder.position();
+}
+
} // namespace internal
} // namespace arrow
diff --git a/cpp/src/arrow/util/formatting.h b/cpp/src/arrow/util/formatting.h
index 71bae74629..6125f792ff 100644
--- a/cpp/src/arrow/util/formatting.h
+++ b/cpp/src/arrow/util/formatting.h
@@ -268,6 +268,7 @@ class ARROW_EXPORT FloatToStringFormatter {
// Returns the number of characters written
int FormatFloat(float v, char* out_buffer, int out_size);
int FormatFloat(double v, char* out_buffer, int out_size);
+ int FormatFloat(uint16_t v, char* out_buffer, int out_size);
protected:
struct Impl;
@@ -301,6 +302,12 @@ class FloatToStringFormatterMixin : public
FloatToStringFormatter {
}
};
+template <>  // half floats use the shared FloatToStringFormatterMixin
+class StringFormatter<HalfFloatType> : public
FloatToStringFormatterMixin<HalfFloatType> {
+ public:
+ using FloatToStringFormatterMixin::FloatToStringFormatterMixin;  // inherit the mixin's constructors
+};
+
template <>
class StringFormatter<FloatType> : public
FloatToStringFormatterMixin<FloatType> {
public:
diff --git a/cpp/src/arrow/util/value_parsing.cc
b/cpp/src/arrow/util/value_parsing.cc
index f6a24ac146..e84aac995e 100644
--- a/cpp/src/arrow/util/value_parsing.cc
+++ b/cpp/src/arrow/util/value_parsing.cc
@@ -22,8 +22,11 @@
#include <string>
#include <utility>
+#include "arrow/util/float16.h"
#include "arrow/vendored/fast_float/fast_float.h"
+using arrow::util::Float16;
+
namespace arrow {
namespace internal {
@@ -43,6 +46,17 @@ bool StringToFloat(const char* s, size_t length, char
decimal_point, double* out
return res.ec == std::errc() && res.ptr == s + length;
}
+// Half float: parses the text as a float, then stores the raw 16-bit pattern
+// of the corresponding Float16 in `*out`.
+// Returns true only when parsing reported no error and consumed all `length`
+// characters.
+bool StringToFloat(const char* s, size_t length, char decimal_point, uint16_t* out) {
+  ::arrow_vendored::fast_float::parse_options options{
+      ::arrow_vendored::fast_float::chars_format::general, decimal_point};
+  // Initialized so that the conversion below never reads an indeterminate
+  // value when the parser fails without writing a result.
+  float parsed = 0.0f;
+  const auto res =
+      ::arrow_vendored::fast_float::from_chars_advanced(s, s + length, parsed, options);
+  *out = Float16::FromFloat(parsed).bits();
+  return res.ec == std::errc() && res.ptr == s + length;
+}
+
// ----------------------------------------------------------------------
// strptime-like parsing
diff --git a/cpp/src/arrow/util/value_parsing.h
b/cpp/src/arrow/util/value_parsing.h
index b3c711840f..609906052c 100644
--- a/cpp/src/arrow/util/value_parsing.h
+++ b/cpp/src/arrow/util/value_parsing.h
@@ -135,6 +135,9 @@ bool StringToFloat(const char* s, size_t length, char
decimal_point, float* out)
ARROW_EXPORT
bool StringToFloat(const char* s, size_t length, char decimal_point, double*
out);
+ARROW_EXPORT
+bool StringToFloat(const char* s, size_t length, char decimal_point, uint16_t*
out);
+
template <>
struct StringConverter<FloatType> {
using value_type = float;
@@ -163,6 +166,20 @@ struct StringConverter<DoubleType> {
const char decimal_point;
};
+// String -> half float converter. The parsed value is delivered as the raw
+// uint16_t bit pattern written by StringToFloat().
+template <>
+struct StringConverter<HalfFloatType> {
+  using value_type = uint16_t;
+
+  explicit StringConverter(char decimal_point = '.') : decimal_point(decimal_point) {}
+
+  // Returns whether [s, s + length) was parsed successfully into `out`.
+  bool Convert(const HalfFloatType&, const char* s, size_t length, value_type* out) {
+    const bool parsed = StringToFloat(s, length, decimal_point, out);
+    return ARROW_PREDICT_TRUE(parsed);
+  }
+
+ private:
+  const char decimal_point;
+};
+
// NOTE: HalfFloatType would require a half<->float conversion library
inline uint8_t ParseDecimalDigit(char c) { return static_cast<uint8_t>(c -
'0'); }
diff --git a/docs/source/status.rst b/docs/source/status.rst
index 9af2fd1921..71d33eaa65 100644
--- a/docs/source/status.rst
+++ b/docs/source/status.rst
@@ -40,7 +40,7 @@ Data Types
+-------------------+-------+-------+-------+------------+-------+-------+-------+-------+
| UInt8/16/32/64 | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓
| ✓ |
+-------------------+-------+-------+-------+------------+-------+-------+-------+-------+
-| Float16 | ✓ (1) | ✓ (2) | ✓ | ✓ | ✓ (3)| ✓ | ✓
| |
+| Float16 | ✓ | ✓ (1) | ✓ | ✓ | ✓ (2)| ✓ | ✓
| |
+-------------------+-------+-------+-------+------------+-------+-------+-------+-------+
| Float32/64 | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓
| ✓ |
+-------------------+-------+-------+-------+------------+-------+-------+-------+-------+
@@ -104,7 +104,7 @@ Data Types
| Data type | C++ | Java | Go | JavaScript | C# | Rust |
Julia | Swift |
| (special) | | | | | | |
| |
+===================+=======+=======+=======+============+=======+=======+=======+=======+
-| Dictionary | ✓ | ✓ (4) | ✓ | ✓ | ✓ | ✓ (3) | ✓
| |
+| Dictionary | ✓ | ✓ (3) | ✓ | ✓ | ✓ | ✓ (3) | ✓
| |
+-------------------+-------+-------+-------+------------+-------+-------+-------+-------+
| Extension | ✓ | ✓ | ✓ | | | ✓ | ✓
| |
+-------------------+-------+-------+-------+------------+-------+-------+-------+-------+
@@ -113,10 +113,9 @@ Data Types
Notes:
-* \(1) Casting to/from Float16 in C++ is not supported.
-* \(2) Casting to/from Float16 in Java is not supported.
-* \(3) Float16 support in C# is only available when targeting .NET 6+.
-* \(4) Nested dictionaries not supported
+* \(1) Casting to/from Float16 in Java is not supported.
+* \(2) Float16 support in C# is only available when targeting .NET 6+.
+* \(3) Nested dictionaries not supported
.. seealso::
The :ref:`format_columnar` specification.