This is an automated email from the ASF dual-hosted git repository.
uwe pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 7483331 ARROW-3158: [C++] Handle float truncation during casting
7483331 is described below
commit 7483331548c06fd743ee0b5b96f3abfebaf94beb
Author: Krisztián Szűcs <[email protected]>
AuthorDate: Tue Sep 4 09:43:40 2018 +0200
ARROW-3158: [C++] Handle float truncation during casting
Author: Krisztián Szűcs <[email protected]>
Closes #2503 from kszucs/ARROW-3158 and squashes the following commits:
2ca76bbd <Krisztián Szűcs> check-format
a8841aa2 <Krisztián Szűcs> correct input data type in test case
b0d3ebac <Krisztián Szűcs> set allow_float_truncate true by default
4809bfab <Krisztián Szűcs> allow truncate float option and its implementation
---
cpp/src/arrow/compute/compute-test.cc | 124 ++++++++++++++++++++--------------
cpp/src/arrow/compute/kernels/cast.cc | 62 +++++++++++++++++
cpp/src/arrow/compute/kernels/cast.h | 15 +++-
3 files changed, 150 insertions(+), 51 deletions(-)
diff --git a/cpp/src/arrow/compute/compute-test.cc b/cpp/src/arrow/compute/compute-test.cc
index 8bf7d1d..a1dfdef 100644
--- a/cpp/src/arrow/compute/compute-test.cc
+++ b/cpp/src/arrow/compute/compute-test.cc
@@ -184,12 +184,6 @@ TEST_F(TestCast, ToIntUpcast) {
vector<int16_t> e3 = {0, 100, 200, 255, 0};
CheckCase<UInt8Type, uint8_t, Int16Type, int16_t>(uint8(), v3, is_valid,
int16(), e3,
options);
-
- // floating point to integer
- vector<double> v4 = {1.5, 0, 0.5, -1.5, 5.5};
- vector<int32_t> e4 = {1, 0, 0, -1, 5};
- CheckCase<DoubleType, double, Int32Type, int32_t>(float64(), v4, is_valid,
int32(), e4,
- options);
}
TEST_F(TestCast, OverflowInNullSlot) {
@@ -218,32 +212,32 @@ TEST_F(TestCast, ToIntDowncastSafe) {
vector<bool> is_valid = {true, false, true, true, true};
// int16 to uint8, no overflow/underrun
- vector<int16_t> v5 = {0, 100, 200, 1, 2};
- vector<uint8_t> e5 = {0, 100, 200, 1, 2};
- CheckCase<Int16Type, int16_t, UInt8Type, uint8_t>(int16(), v5, is_valid,
uint8(), e5,
+ vector<int16_t> v1 = {0, 100, 200, 1, 2};
+ vector<uint8_t> e1 = {0, 100, 200, 1, 2};
+ CheckCase<Int16Type, int16_t, UInt8Type, uint8_t>(int16(), v1, is_valid,
uint8(), e1,
options);
// int16 to uint8, with overflow
- vector<int16_t> v6 = {0, 100, 256, 0, 0};
- CheckFails<Int16Type>(int16(), v6, is_valid, uint8(), options);
+ vector<int16_t> v2 = {0, 100, 256, 0, 0};
+ CheckFails<Int16Type>(int16(), v2, is_valid, uint8(), options);
// underflow
- vector<int16_t> v7 = {0, 100, -1, 0, 0};
- CheckFails<Int16Type>(int16(), v7, is_valid, uint8(), options);
+ vector<int16_t> v3 = {0, 100, -1, 0, 0};
+ CheckFails<Int16Type>(int16(), v3, is_valid, uint8(), options);
// int32 to int16, no overflow
- vector<int32_t> v8 = {0, 1000, 2000, 1, 2};
- vector<int16_t> e8 = {0, 1000, 2000, 1, 2};
- CheckCase<Int32Type, int32_t, Int16Type, int16_t>(int32(), v8, is_valid,
int16(), e8,
+ vector<int32_t> v4 = {0, 1000, 2000, 1, 2};
+ vector<int16_t> e4 = {0, 1000, 2000, 1, 2};
+ CheckCase<Int32Type, int32_t, Int16Type, int16_t>(int32(), v4, is_valid,
int16(), e4,
options);
// int32 to int16, overflow
- vector<int32_t> v9 = {0, 1000, 2000, 70000, 0};
- CheckFails<Int32Type>(int32(), v9, is_valid, int16(), options);
+ vector<int32_t> v5 = {0, 1000, 2000, 70000, 0};
+ CheckFails<Int32Type>(int32(), v5, is_valid, int16(), options);
// underflow
- vector<int32_t> v10 = {0, 1000, 2000, -70000, 0};
- CheckFails<Int32Type>(int32(), v9, is_valid, int16(), options);
+ vector<int32_t> v6 = {0, 1000, 2000, -70000, 0};
+ CheckFails<Int32Type>(int32(), v6, is_valid, int16(), options);
}
TEST_F(TestCast, ToIntDowncastUnsafe) {
@@ -253,41 +247,75 @@ TEST_F(TestCast, ToIntDowncastUnsafe) {
vector<bool> is_valid = {true, false, true, true, true};
// int16 to uint8, no overflow/underrun
- vector<int16_t> v5 = {0, 100, 200, 1, 2};
- vector<uint8_t> e5 = {0, 100, 200, 1, 2};
- CheckCase<Int16Type, int16_t, UInt8Type, uint8_t>(int16(), v5, is_valid,
uint8(), e5,
+ vector<int16_t> v1 = {0, 100, 200, 1, 2};
+ vector<uint8_t> e1 = {0, 100, 200, 1, 2};
+ CheckCase<Int16Type, int16_t, UInt8Type, uint8_t>(int16(), v1, is_valid,
uint8(), e1,
options);
// int16 to uint8, with overflow
- vector<int16_t> v6 = {0, 100, 256, 0, 0};
- vector<uint8_t> e6 = {0, 100, 0, 0, 0};
- CheckCase<Int16Type, int16_t, UInt8Type, uint8_t>(int16(), v6, is_valid,
uint8(), e6,
+ vector<int16_t> v2 = {0, 100, 256, 0, 0};
+ vector<uint8_t> e2 = {0, 100, 0, 0, 0};
+ CheckCase<Int16Type, int16_t, UInt8Type, uint8_t>(int16(), v2, is_valid,
uint8(), e2,
options);
// underflow
- vector<int16_t> v7 = {0, 100, -1, 0, 0};
- vector<uint8_t> e7 = {0, 100, 255, 0, 0};
- CheckCase<Int16Type, int16_t, UInt8Type, uint8_t>(int16(), v7, is_valid,
uint8(), e7,
+ vector<int16_t> v3 = {0, 100, -1, 0, 0};
+ vector<uint8_t> e3 = {0, 100, 255, 0, 0};
+ CheckCase<Int16Type, int16_t, UInt8Type, uint8_t>(int16(), v3, is_valid,
uint8(), e3,
options);
// int32 to int16, no overflow
- vector<int32_t> v8 = {0, 1000, 2000, 1, 2};
- vector<int16_t> e8 = {0, 1000, 2000, 1, 2};
- CheckCase<Int32Type, int32_t, Int16Type, int16_t>(int32(), v8, is_valid,
int16(), e8,
+ vector<int32_t> v4 = {0, 1000, 2000, 1, 2};
+ vector<int16_t> e4 = {0, 1000, 2000, 1, 2};
+ CheckCase<Int32Type, int32_t, Int16Type, int16_t>(int32(), v4, is_valid,
int16(), e4,
options);
// int32 to int16, overflow
// TODO(wesm): do we want to allow this? we could set to null
- vector<int32_t> v9 = {0, 1000, 2000, 70000, 0};
- vector<int16_t> e9 = {0, 1000, 2000, 4464, 0};
- CheckCase<Int32Type, int32_t, Int16Type, int16_t>(int32(), v9, is_valid,
int16(), e9,
+ vector<int32_t> v5 = {0, 1000, 2000, 70000, 0};
+ vector<int16_t> e5 = {0, 1000, 2000, 4464, 0};
+ CheckCase<Int32Type, int32_t, Int16Type, int16_t>(int32(), v5, is_valid,
int16(), e5,
options);
// underflow
// TODO(wesm): do we want to allow this? we could set overflow to null
- vector<int32_t> v10 = {0, 1000, 2000, -70000, 0};
- vector<int16_t> e10 = {0, 1000, 2000, -4464, 0};
- CheckCase<Int32Type, int32_t, Int16Type, int16_t>(int32(), v10, is_valid,
int16(), e10,
+ vector<int32_t> v6 = {0, 1000, 2000, -70000, 0};
+ vector<int16_t> e6 = {0, 1000, 2000, -4464, 0};
+ CheckCase<Int32Type, int32_t, Int16Type, int16_t>(int32(), v6, is_valid,
int16(), e6,
+ options);
+}
+
+TEST_F(TestCast, FloatingPointToInt) {
+ // All checks below start from Safe() options (allow_float_truncate == false).
+ auto options = CastOptions::Safe();
+
+ vector<bool> is_valid = {true, false, true, true, true};
+ vector<bool> all_valid = {true, true, true, true, true};
+
+ // float32 to int32: fractional values are truncated even with Safe() options.
+ // NOTE(review): this passes only because is_float_downcast requires a strictly
+ // narrower output type, so same-width float32 -> int32 is never
+ // truncation-checked; confirm this is intended rather than a gap.
+ vector<float> v1 = {1.5, 0, 0.5, -1.5, 5.5};
+ vector<int32_t> e1 = {1, 0, 0, -1, 5};
+ CheckCase<FloatType, float, Int32Type, int32_t>(float32(), v1, is_valid,
int32(), e1,
+ options);
+ CheckCase<FloatType, float, Int32Type, int32_t>(float32(), v1, all_valid,
int32(), e1,
+ options);
+
+ // float64 to int32 with integral values: nothing is truncated, so Safe() succeeds
+ vector<double> v2 = {1.0, 0, 0.0, -1.0, 5.0};
+ vector<int32_t> e2 = {1, 0, 0, -1, 5};
+ CheckCase<DoubleType, double, Int32Type, int32_t>(float64(), v2, is_valid,
int32(), e2,
+ options);
+ CheckCase<DoubleType, double, Int32Type, int32_t>(float64(), v2, all_valid,
int32(), e2,
+ options);
+
+ // float64 to int32 with fractional values: Safe() rejects the truncation
+ // NOTE(review): the null slot of is_valid holds 0, so no case here puts a
+ // fractional value in a null slot.
+ vector<double> v3 = {1.5, 0, 0.5, -1.5, 5.5};
+ vector<int32_t> e3 = {1, 0, 0, -1, 5};
+ CheckFails<DoubleType>(float64(), v3, is_valid, int32(), options);
+ CheckFails<DoubleType>(float64(), v3, all_valid, int32(), options);
+
+ // once float truncation is allowed, the same cast succeeds, truncating toward zero
+ options.allow_float_truncate = true;
+ CheckCase<DoubleType, double, Int32Type, int32_t>(float64(), v3, is_valid,
int32(), e3,
+ options);
+ CheckCase<DoubleType, double, Int32Type, int32_t>(float64(), v3, all_valid,
int32(), e3,
options);
}
@@ -982,18 +1010,14 @@ TEST_F(TestCast, ListToList) {
ASSERT_OK(
ListArray::FromArrays(*offsets, *float64_plain_array, pool_,
&float64_list_array));
- this->CheckPass(*int32_list_array, *int64_list_array,
int64_list_array->type(),
- options);
- this->CheckPass(*int32_list_array, *float64_list_array,
float64_list_array->type(),
- options);
- this->CheckPass(*int64_list_array, *int32_list_array,
int32_list_array->type(),
- options);
- this->CheckPass(*int64_list_array, *float64_list_array,
float64_list_array->type(),
- options);
- this->CheckPass(*float64_list_array, *int32_list_array,
int32_list_array->type(),
- options);
- this->CheckPass(*float64_list_array, *int64_list_array,
int64_list_array->type(),
- options);
+ CheckPass(*int32_list_array, *int64_list_array, int64_list_array->type(),
options);
+ CheckPass(*int32_list_array, *float64_list_array,
float64_list_array->type(), options);
+ CheckPass(*int64_list_array, *int32_list_array, int32_list_array->type(),
options);
+ CheckPass(*int64_list_array, *float64_list_array,
float64_list_array->type(), options);
+
+ options.allow_float_truncate = true;
+ CheckPass(*float64_list_array, *int32_list_array, int32_list_array->type(),
options);
+ CheckPass(*float64_list_array, *int64_list_array, int64_list_array->type(),
options);
}
// ----------------------------------------------------------------------
diff --git a/cpp/src/arrow/compute/kernels/cast.cc b/cpp/src/arrow/compute/kernels/cast.cc
index 1101ce7..2a0479d 100644
--- a/cpp/src/arrow/compute/kernels/cast.cc
+++ b/cpp/src/arrow/compute/kernels/cast.cc
@@ -193,6 +193,23 @@ struct is_integer_downcast<
(sizeof(O_T) < sizeof(I_T))));
};
+// Trait: true when casting from a floating point input type I to a numeric
+// output type O whose C type is strictly smaller
+// (sizeof(O::c_type) < sizeof(I::c_type)). Casts selected here are handled by
+// the truncation-checking CastFunctor below and are excluded from the generic
+// numeric CastFunctor.
+// NOTE(review): because the size comparison is strict, same-width casts such as
+// float32 -> int32 or float64 -> int64 are NOT selected and so are never
+// truncation-checked, even with CastOptions::Safe(). Conversely, float64 ->
+// float32 is selected if FloatType derives from Number (presumably so; verify),
+// in which case precision loss there is reported as truncation.
+template <typename O, typename I, typename Enable = void>
+struct is_float_downcast {
+ static constexpr bool value = false;
+};
+
+// Specialization for a Number output and a FloatingPoint input.
+template <typename O, typename I>
+struct is_float_downcast<
+ O, I,
+ typename std::enable_if<std::is_base_of<Number, O>::value &&
+ std::is_base_of<FloatingPoint, I>::value>::type> {
+ using O_T = typename O::c_type;
+ using I_T = typename I::c_type;
+
+ // Smaller output size
+ static constexpr bool value = !std::is_same<O, I>::value && (sizeof(O_T) <
sizeof(I_T));
+};
+
template <typename O, typename I>
struct CastFunctor<O, I,
typename std::enable_if<std::is_same<BooleanType, O>::value
&&
@@ -253,8 +270,53 @@ struct CastFunctor<O, I,
};
+// Cast from a floating point type I to a numeric type O selected by
+// is_float_downcast (an output C type strictly smaller than the input's).
+//
+// With options.allow_float_truncate the values are converted by a plain
+// static_cast, which discards any fractional part. Otherwise every non-null
+// slot whose converted value no longer compares equal to the input (e.g.
+// 1.5 -> 1) sets Status::Invalid on ctx; converted values are still written.
+// Null slots are excluded from the check because they may hold arbitrary values.
+//
+// NOTE(review): static_cast from a floating point value that is NaN or outside
+// the range of out_type is undefined behavior in C++; no range check is done
+// here -- confirm whether such inputs can reach this kernel.
template <typename O, typename I>
+struct CastFunctor<O, I,
+                   typename std::enable_if<is_float_downcast<O, I>::value>::type> {
+  void operator()(FunctionContext* ctx, const CastOptions& options,
+                  const ArrayData& input, ArrayData* output) {
+    using in_type = typename I::c_type;
+    using out_type = typename O::c_type;
+
+    auto in_offset = input.offset;
+    const in_type* in_data = GetValues<in_type>(input, 1);
+    auto out_data = GetMutableValues<out_type>(output, 1);
+
+    if (options.allow_float_truncate) {
+      // unsafe cast: truncation is permitted, nothing to check
+      for (int64_t i = 0; i < input.length; ++i) {
+        *out_data++ = static_cast<out_type>(*in_data++);
+      }
+    } else if (input.null_count != 0) {
+      // safe cast with nulls: only slots marked valid in the bitmap are checked,
+      // so garbage in a null slot cannot fail the cast
+      internal::BitmapReader is_valid_reader(input.buffers[0]->data(), in_offset,
+                                             input.length);
+      for (int64_t i = 0; i < input.length; ++i) {
+        auto out_value = static_cast<out_type>(*in_data);
+        if (ARROW_PREDICT_FALSE(is_valid_reader.IsSet() && out_value != *in_data)) {
+          ctx->SetStatus(Status::Invalid("Floating point value truncated"));
+        }
+        *out_data++ = out_value;
+        in_data++;
+        is_valid_reader.Next();
+      }
+    } else {
+      // safe cast without nulls: every slot is checked
+      for (int64_t i = 0; i < input.length; ++i) {
+        auto out_value = static_cast<out_type>(*in_data);
+        if (ARROW_PREDICT_FALSE(out_value != *in_data)) {
+          ctx->SetStatus(Status::Invalid("Floating point value truncated"));
+        }
+        *out_data++ = out_value;
+        in_data++;
+      }
+    }
+  }
+};
+
+template <typename O, typename I>
struct CastFunctor<O, I,
typename std::enable_if<is_numeric_cast<O, I>::value &&
+ !is_float_downcast<O, I>::value &&
!is_integer_downcast<O,
I>::value>::type> {
void operator()(FunctionContext* ctx, const CastOptions& options,
const ArrayData& input, ArrayData* output) {
diff --git a/cpp/src/arrow/compute/kernels/cast.h b/cpp/src/arrow/compute/kernels/cast.h
index b75bb7b..8392c18 100644
--- a/cpp/src/arrow/compute/kernels/cast.h
+++ b/cpp/src/arrow/compute/kernels/cast.h
@@ -35,10 +35,23 @@ class DataType;
namespace compute {
+/// Options controlling which lossy conversions a cast is allowed to perform.
struct ARROW_EXPORT CastOptions {
- CastOptions() : allow_int_overflow(false), allow_time_truncate(false) {}
+ /// Default options: integer overflow and time truncation are rejected, but
+ /// float truncation is allowed (allow_float_truncate defaults to true).
+ CastOptions()
+ : allow_int_overflow(false),
+ allow_time_truncate(false),
+ allow_float_truncate(true) {}
+
+ /// \param safe if true, every allow_* flag is false (no lossy conversions);
+ /// if false, every allow_* flag is true.
+ explicit CastOptions(bool safe)
+ : allow_int_overflow(!safe),
+ allow_time_truncate(!safe),
+ allow_float_truncate(!safe) {}
+
+ /// Options rejecting integer overflow, time truncation and float truncation.
+ static CastOptions Safe() { return CastOptions(true); }
+
+ /// Options allowing integer overflow, time truncation and float truncation.
+ static CastOptions Unsafe() { return CastOptions(false); }
bool allow_int_overflow;
bool allow_time_truncate;
+ /// If false, a float cast handled by the truncation-checking kernel fails when
+ /// a value changes (e.g. 1.5 -> 1); if true, values are silently truncated.
+ bool allow_float_truncate;
};
/// \since 0.7.0