This is an automated email from the ASF dual-hosted git repository.

uwe pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 7483331  ARROW-3158: [C++] Handle float truncation during casting
7483331 is described below

commit 7483331548c06fd743ee0b5b96f3abfebaf94beb
Author: Krisztián Szűcs <[email protected]>
AuthorDate: Tue Sep 4 09:43:40 2018 +0200

    ARROW-3158: [C++] Handle float truncation during casting
    
    Author: Krisztián Szűcs <[email protected]>
    
    Closes #2503 from kszucs/ARROW-3158 and squashes the following commits:
    
    2ca76bbd <Krisztián Szűcs> check-format
    a8841aa2 <Krisztián Szűcs> correct input data type in test case
    b0d3ebac <Krisztián Szűcs> set allow_float_truncate true by default
    4809bfab <Krisztián Szűcs> allow truncate float option and its implementation
---
 cpp/src/arrow/compute/compute-test.cc | 124 ++++++++++++++++++++--------------
 cpp/src/arrow/compute/kernels/cast.cc |  62 +++++++++++++++++
 cpp/src/arrow/compute/kernels/cast.h  |  15 +++-
 3 files changed, 150 insertions(+), 51 deletions(-)

diff --git a/cpp/src/arrow/compute/compute-test.cc b/cpp/src/arrow/compute/compute-test.cc
index 8bf7d1d..a1dfdef 100644
--- a/cpp/src/arrow/compute/compute-test.cc
+++ b/cpp/src/arrow/compute/compute-test.cc
@@ -184,12 +184,6 @@ TEST_F(TestCast, ToIntUpcast) {
   vector<int16_t> e3 = {0, 100, 200, 255, 0};
   CheckCase<UInt8Type, uint8_t, Int16Type, int16_t>(uint8(), v3, is_valid, 
int16(), e3,
                                                     options);
-
-  // floating point to integer
-  vector<double> v4 = {1.5, 0, 0.5, -1.5, 5.5};
-  vector<int32_t> e4 = {1, 0, 0, -1, 5};
-  CheckCase<DoubleType, double, Int32Type, int32_t>(float64(), v4, is_valid, 
int32(), e4,
-                                                    options);
 }
 
 TEST_F(TestCast, OverflowInNullSlot) {
@@ -218,32 +212,32 @@ TEST_F(TestCast, ToIntDowncastSafe) {
   vector<bool> is_valid = {true, false, true, true, true};
 
   // int16 to uint8, no overflow/underrun
-  vector<int16_t> v5 = {0, 100, 200, 1, 2};
-  vector<uint8_t> e5 = {0, 100, 200, 1, 2};
-  CheckCase<Int16Type, int16_t, UInt8Type, uint8_t>(int16(), v5, is_valid, 
uint8(), e5,
+  vector<int16_t> v1 = {0, 100, 200, 1, 2};
+  vector<uint8_t> e1 = {0, 100, 200, 1, 2};
+  CheckCase<Int16Type, int16_t, UInt8Type, uint8_t>(int16(), v1, is_valid, 
uint8(), e1,
                                                     options);
 
   // int16 to uint8, with overflow
-  vector<int16_t> v6 = {0, 100, 256, 0, 0};
-  CheckFails<Int16Type>(int16(), v6, is_valid, uint8(), options);
+  vector<int16_t> v2 = {0, 100, 256, 0, 0};
+  CheckFails<Int16Type>(int16(), v2, is_valid, uint8(), options);
 
   // underflow
-  vector<int16_t> v7 = {0, 100, -1, 0, 0};
-  CheckFails<Int16Type>(int16(), v7, is_valid, uint8(), options);
+  vector<int16_t> v3 = {0, 100, -1, 0, 0};
+  CheckFails<Int16Type>(int16(), v3, is_valid, uint8(), options);
 
   // int32 to int16, no overflow
-  vector<int32_t> v8 = {0, 1000, 2000, 1, 2};
-  vector<int16_t> e8 = {0, 1000, 2000, 1, 2};
-  CheckCase<Int32Type, int32_t, Int16Type, int16_t>(int32(), v8, is_valid, 
int16(), e8,
+  vector<int32_t> v4 = {0, 1000, 2000, 1, 2};
+  vector<int16_t> e4 = {0, 1000, 2000, 1, 2};
+  CheckCase<Int32Type, int32_t, Int16Type, int16_t>(int32(), v4, is_valid, 
int16(), e4,
                                                     options);
 
   // int32 to int16, overflow
-  vector<int32_t> v9 = {0, 1000, 2000, 70000, 0};
-  CheckFails<Int32Type>(int32(), v9, is_valid, int16(), options);
+  vector<int32_t> v5 = {0, 1000, 2000, 70000, 0};
+  CheckFails<Int32Type>(int32(), v5, is_valid, int16(), options);
 
   // underflow
-  vector<int32_t> v10 = {0, 1000, 2000, -70000, 0};
-  CheckFails<Int32Type>(int32(), v9, is_valid, int16(), options);
+  vector<int32_t> v6 = {0, 1000, 2000, -70000, 0};
+  CheckFails<Int32Type>(int32(), v6, is_valid, int16(), options);
 }
 
 TEST_F(TestCast, ToIntDowncastUnsafe) {
@@ -253,41 +247,75 @@ TEST_F(TestCast, ToIntDowncastUnsafe) {
   vector<bool> is_valid = {true, false, true, true, true};
 
   // int16 to uint8, no overflow/underrun
-  vector<int16_t> v5 = {0, 100, 200, 1, 2};
-  vector<uint8_t> e5 = {0, 100, 200, 1, 2};
-  CheckCase<Int16Type, int16_t, UInt8Type, uint8_t>(int16(), v5, is_valid, 
uint8(), e5,
+  vector<int16_t> v1 = {0, 100, 200, 1, 2};
+  vector<uint8_t> e1 = {0, 100, 200, 1, 2};
+  CheckCase<Int16Type, int16_t, UInt8Type, uint8_t>(int16(), v1, is_valid, 
uint8(), e1,
                                                     options);
 
   // int16 to uint8, with overflow
-  vector<int16_t> v6 = {0, 100, 256, 0, 0};
-  vector<uint8_t> e6 = {0, 100, 0, 0, 0};
-  CheckCase<Int16Type, int16_t, UInt8Type, uint8_t>(int16(), v6, is_valid, 
uint8(), e6,
+  vector<int16_t> v2 = {0, 100, 256, 0, 0};
+  vector<uint8_t> e2 = {0, 100, 0, 0, 0};
+  CheckCase<Int16Type, int16_t, UInt8Type, uint8_t>(int16(), v2, is_valid, 
uint8(), e2,
                                                     options);
 
   // underflow
-  vector<int16_t> v7 = {0, 100, -1, 0, 0};
-  vector<uint8_t> e7 = {0, 100, 255, 0, 0};
-  CheckCase<Int16Type, int16_t, UInt8Type, uint8_t>(int16(), v7, is_valid, 
uint8(), e7,
+  vector<int16_t> v3 = {0, 100, -1, 0, 0};
+  vector<uint8_t> e3 = {0, 100, 255, 0, 0};
+  CheckCase<Int16Type, int16_t, UInt8Type, uint8_t>(int16(), v3, is_valid, 
uint8(), e3,
                                                     options);
 
   // int32 to int16, no overflow
-  vector<int32_t> v8 = {0, 1000, 2000, 1, 2};
-  vector<int16_t> e8 = {0, 1000, 2000, 1, 2};
-  CheckCase<Int32Type, int32_t, Int16Type, int16_t>(int32(), v8, is_valid, 
int16(), e8,
+  vector<int32_t> v4 = {0, 1000, 2000, 1, 2};
+  vector<int16_t> e4 = {0, 1000, 2000, 1, 2};
+  CheckCase<Int32Type, int32_t, Int16Type, int16_t>(int32(), v4, is_valid, 
int16(), e4,
                                                     options);
 
   // int32 to int16, overflow
   // TODO(wesm): do we want to allow this? we could set to null
-  vector<int32_t> v9 = {0, 1000, 2000, 70000, 0};
-  vector<int16_t> e9 = {0, 1000, 2000, 4464, 0};
-  CheckCase<Int32Type, int32_t, Int16Type, int16_t>(int32(), v9, is_valid, 
int16(), e9,
+  vector<int32_t> v5 = {0, 1000, 2000, 70000, 0};
+  vector<int16_t> e5 = {0, 1000, 2000, 4464, 0};
+  CheckCase<Int32Type, int32_t, Int16Type, int16_t>(int32(), v5, is_valid, 
int16(), e5,
                                                     options);
 
   // underflow
   // TODO(wesm): do we want to allow this? we could set overflow to null
-  vector<int32_t> v10 = {0, 1000, 2000, -70000, 0};
-  vector<int16_t> e10 = {0, 1000, 2000, -4464, 0};
-  CheckCase<Int32Type, int32_t, Int16Type, int16_t>(int32(), v10, is_valid, 
int16(), e10,
+  vector<int32_t> v6 = {0, 1000, 2000, -70000, 0};
+  vector<int16_t> e6 = {0, 1000, 2000, -4464, 0};
+  CheckCase<Int32Type, int32_t, Int16Type, int16_t>(int32(), v6, is_valid, 
int16(), e6,
+                                                    options);
+}
+
+TEST_F(TestCast, FloatingPointToInt) {
+  auto options = CastOptions::Safe();  // allow_float_truncate == false
+
+  vector<bool> is_valid = {true, false, true, true, true};
+  vector<bool> all_valid = {true, true, true, true, true};
+
+  // float32 -> int32 is same-width (not an is_float_downcast): Safe() truncates
+  vector<float> v1 = {1.5, 0, 0.5, -1.5, 5.5};
+  vector<int32_t> e1 = {1, 0, 0, -1, 5};
+  CheckCase<FloatType, float, Int32Type, int32_t>(float32(), v1, is_valid, 
int32(), e1,
+                                                  options);
+  CheckCase<FloatType, float, Int32Type, int32_t>(float32(), v1, all_valid, 
int32(), e1,
+                                                  options);
+
+  // float64 -> int32 with integral values: the safe truncation check passes
+  vector<double> v2 = {1.0, 0, 0.0, -1.0, 5.0};
+  vector<int32_t> e2 = {1, 0, 0, -1, 5};
+  CheckCase<DoubleType, double, Int32Type, int32_t>(float64(), v2, is_valid, 
int32(), e2,
+                                                    options);
+  CheckCase<DoubleType, double, Int32Type, int32_t>(float64(), v2, all_valid, 
int32(), e2,
+                                                    options);
+
+  vector<double> v3 = {1.5, 0, 0.5, -1.5, 5.5};  // fractional values
+  vector<int32_t> e3 = {1, 0, 0, -1, 5};  // expected once truncation is allowed
+  CheckFails<DoubleType>(float64(), v3, is_valid, int32(), options);  // safe: fails
+  CheckFails<DoubleType>(float64(), v3, all_valid, int32(), options);
+
+  options.allow_float_truncate = true;  // opt in to truncation
+  CheckCase<DoubleType, double, Int32Type, int32_t>(float64(), v3, is_valid, 
int32(), e3,
+                                                    options);
+  CheckCase<DoubleType, double, Int32Type, int32_t>(float64(), v3, all_valid, 
int32(), e3,
                                                     options);
 }
 
@@ -982,18 +1010,14 @@ TEST_F(TestCast, ListToList) {
   ASSERT_OK(
       ListArray::FromArrays(*offsets, *float64_plain_array, pool_, 
&float64_list_array));
 
-  this->CheckPass(*int32_list_array, *int64_list_array, 
int64_list_array->type(),
-                  options);
-  this->CheckPass(*int32_list_array, *float64_list_array, 
float64_list_array->type(),
-                  options);
-  this->CheckPass(*int64_list_array, *int32_list_array, 
int32_list_array->type(),
-                  options);
-  this->CheckPass(*int64_list_array, *float64_list_array, 
float64_list_array->type(),
-                  options);
-  this->CheckPass(*float64_list_array, *int32_list_array, 
int32_list_array->type(),
-                  options);
-  this->CheckPass(*float64_list_array, *int64_list_array, 
int64_list_array->type(),
-                  options);
+  CheckPass(*int32_list_array, *int64_list_array, int64_list_array->type(), 
options);
+  CheckPass(*int32_list_array, *float64_list_array, 
float64_list_array->type(), options);
+  CheckPass(*int64_list_array, *int32_list_array, int32_list_array->type(), 
options);
+  CheckPass(*int64_list_array, *float64_list_array, 
float64_list_array->type(), options);
+
+  options.allow_float_truncate = true;
+  CheckPass(*float64_list_array, *int32_list_array, int32_list_array->type(), 
options);
+  CheckPass(*float64_list_array, *int64_list_array, int64_list_array->type(), 
options);
 }
 
 // ----------------------------------------------------------------------
diff --git a/cpp/src/arrow/compute/kernels/cast.cc b/cpp/src/arrow/compute/kernels/cast.cc
index 1101ce7..2a0479d 100644
--- a/cpp/src/arrow/compute/kernels/cast.cc
+++ b/cpp/src/arrow/compute/kernels/cast.cc
@@ -193,6 +193,23 @@ struct is_integer_downcast<
         (sizeof(O_T) < sizeof(I_T))));
 };
 
+template <typename O, typename I, typename Enable = void>
+struct is_float_downcast {  // fallback: not routed through the float truncation kernel
+  static constexpr bool value = false;
+};
+
+template <typename O, typename I>
+struct is_float_downcast<
+    O, I,
+    typename std::enable_if<std::is_base_of<Number, O>::value &&
+                            std::is_base_of<FloatingPoint, I>::value>::type> {
+  using O_T = typename O::c_type;
+  using I_T = typename I::c_type;
+
+  // Narrower output only. NOTE(review): float32->int32 / float64->int64 skip this
+  static constexpr bool value = !std::is_same<O, I>::value && (sizeof(O_T) < 
sizeof(I_T));
+};
+
 template <typename O, typename I>
 struct CastFunctor<O, I,
                    typename std::enable_if<std::is_same<BooleanType, O>::value 
&&
@@ -253,8 +270,53 @@ struct CastFunctor<O, I,
 };
 
 template <typename O, typename I>
+struct CastFunctor<O, I, typename std::enable_if<is_float_downcast<O, I>::value>::type> {
+  void operator()(FunctionContext* ctx, const CastOptions& options,
+                  const ArrayData& input, ArrayData* output) {
+    using in_type = typename I::c_type;
+    using out_type = typename O::c_type;
+
+    auto in_offset = input.offset;
+    const in_type* in_data = GetValues<in_type>(input, 1);
+    auto out_data = GetMutableValues<out_type>(output, 1);
+
+    if (options.allow_float_truncate) {
+      // unsafe cast: plain static_cast of every slot, no value check
+      for (int64_t i = 0; i < input.length; ++i) {
+        *out_data++ = static_cast<out_type>(*in_data++);
+      }
+    } else {
+      // safe: fail if the value changes (nulls ignored); NOTE: int O + NaN/out-of-range is UB
+      if (input.null_count != 0) {
+        internal::BitmapReader is_valid_reader(input.buffers[0]->data(), in_offset,
+                                               input.length);
+        for (int64_t i = 0; i < input.length; ++i) {
+          auto out_value = static_cast<out_type>(*in_data);
+          if (ARROW_PREDICT_FALSE(is_valid_reader.IsSet() && out_value != *in_data)) {
+            ctx->SetStatus(Status::Invalid("Floating point value truncated"));
+          }
+          *out_data++ = out_value;
+          in_data++;
+          is_valid_reader.Next();
+        }
+      } else {
+        for (int64_t i = 0; i < input.length; ++i) {
+          auto out_value = static_cast<out_type>(*in_data);
+          if (ARROW_PREDICT_FALSE(out_value != *in_data)) {
+            ctx->SetStatus(Status::Invalid("Floating point value truncated"));
+          }
+          *out_data++ = out_value;
+          in_data++;
+        }
+      }
+    }
+  }
+};
+
+template <typename O, typename I>
 struct CastFunctor<O, I,
                    typename std::enable_if<is_numeric_cast<O, I>::value &&
+                                           !is_float_downcast<O, I>::value &&
                                            !is_integer_downcast<O, 
I>::value>::type> {
   void operator()(FunctionContext* ctx, const CastOptions& options,
                   const ArrayData& input, ArrayData* output) {
diff --git a/cpp/src/arrow/compute/kernels/cast.h b/cpp/src/arrow/compute/kernels/cast.h
index b75bb7b..8392c18 100644
--- a/cpp/src/arrow/compute/kernels/cast.h
+++ b/cpp/src/arrow/compute/kernels/cast.h
@@ -35,10 +35,23 @@ class DataType;
 namespace compute {
 
 struct ARROW_EXPORT CastOptions {
-  CastOptions() : allow_int_overflow(false), allow_time_truncate(false) {}
+  CastOptions()  // NOTE: unlike Safe(), float truncation is allowed by default
+      : allow_int_overflow(false),
+        allow_time_truncate(false),
+        allow_float_truncate(true) {}
+
+  explicit CastOptions(bool safe)  // safe == true sets every allow_* flag to false
+      : allow_int_overflow(!safe),
+        allow_time_truncate(!safe),
+        allow_float_truncate(!safe) {}
+
+  static CastOptions Safe() { return CastOptions(true); }  // all allow_* flags false
+
+  static CastOptions Unsafe() { return CastOptions(false); }  // all allow_* flags true
 
   bool allow_int_overflow;
   bool allow_time_truncate;
+  bool allow_float_truncate;  // if false, float-downcast casts fail when a value changes
 };
 
 /// \since 0.7.0

Reply via email to