pitrou commented on a change in pull request #10349:
URL: https://github.com/apache/arrow/pull/10349#discussion_r704420538



##########
File path: cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
##########
@@ -852,24 +854,232 @@ struct LogbChecked {
   }
 };
 
+struct RoundUtil {
+  template <typename T>
+  static constexpr enable_if_t<std::is_floating_point<T>::value, bool> IsApproxEqual(
+      const T a, const T b, const T abs_tol = kDefaultAbsoluteTolerance) {
+    return std::fabs(a - b) <= abs_tol;
+  }
+
+  template <typename T>
+  static constexpr enable_if_t<std::is_floating_point<T>::value, bool> IsApproxInt(
+      const T val, const T abs_tol = kDefaultAbsoluteTolerance) {
+    // |frac| ~ 0.0?
+    return IsApproxEqual(val, std::round(val), abs_tol);
+  }
+
+  template <typename T>
+  static constexpr enable_if_t<std::is_floating_point<T>::value, bool> IsApproxHalfInt(
+      const T val, const T abs_tol = kDefaultAbsoluteTolerance) {
+    // |frac| ~ 0.5?
+    return IsApproxEqual(val - std::floor(val), T(0.5), abs_tol);
+  }
+
+  template <typename T>
+  static enable_if_floating_point<T> Pow10(const int64_t power) {
+    const T lut[]{1e0F,  1e1F,  1e2F,  1e3F,  1e4F,  1e5F,  1e6F,  1e7F,
+                  1e8F,  1e9F,  1e10F, 1e11F, 1e12F, 1e13F, 1e14F, 1e15F,
+                  1e16F, 1e17F, 1e18F, 1e19F, 1e20F, 1e21F, 1e22F};
+    // Return NaN if index is out-of-range.
+    auto lut_size = (int64_t)(sizeof(lut) / sizeof(*lut));
+    return (power >= 0 && power < lut_size) ? lut[power] : std::nanf("");
+  }
+};
+
+// Specializations of rounding implementations for round kernels
+template <typename T, RoundMode>
+struct RoundImpl {};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::DOWN> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return std::floor(val);
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::UP> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return std::ceil(val);
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::TOWARDS_ZERO> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return std::signbit(val) ? std::ceil(val) : std::floor(val);
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::TOWARDS_INFINITY> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return std::signbit(val) ? std::floor(val) : std::ceil(val);
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::HALF_DOWN> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return RoundImpl<T, RoundMode::DOWN>::Round(val);
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::HALF_UP> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return RoundImpl<T, RoundMode::UP>::Round(val);
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::HALF_TOWARDS_ZERO> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return RoundImpl<T, RoundMode::TOWARDS_ZERO>::Round(val);
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::HALF_TOWARDS_INFINITY> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return RoundImpl<T, RoundMode::TOWARDS_INFINITY>::Round(val);
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::HALF_TO_EVEN> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return std::round(val * T(0.5)) * 2;
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::HALF_TO_ODD> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return std::floor(val * T(0.5)) + std::ceil(val * T(0.5));
+  }
+};
+
+template <RoundMode RndMode>
+struct Round {
+  using RoundState = OptionsWrapper<RoundOptions>;
+
+  template <typename T, typename Arg>
+  static enable_if_floating_point<Arg, T> Call(KernelContext* ctx, Arg arg, Status* st) {
+    static_assert(std::is_same<T, Arg>::value, "");
+    static_assert(RoundMode::HALF_DOWN > RoundMode::DOWN &&
+                      RoundMode::HALF_DOWN > RoundMode::UP &&
+                      RoundMode::HALF_DOWN > RoundMode::TOWARDS_ZERO &&
+                      RoundMode::HALF_DOWN > RoundMode::TOWARDS_INFINITY &&
+                      RoundMode::HALF_DOWN < RoundMode::HALF_UP &&
+                      RoundMode::HALF_DOWN < RoundMode::HALF_TOWARDS_ZERO &&
+                      RoundMode::HALF_DOWN < RoundMode::HALF_TOWARDS_INFINITY &&
+                      RoundMode::HALF_DOWN < RoundMode::HALF_TO_EVEN &&
+                      RoundMode::HALF_DOWN < RoundMode::HALF_TO_ODD,
+                  "Round modes prefixed with HALF need to be defined last in enum and "
+                  "the first HALF entry has to be HALF_DOWN.");
+
+    auto options = RoundState::Get(ctx);
+    auto pow10 = RoundUtil::Pow10<T>(std::llabs(options.ndigits));
+    if (std::isnan(pow10)) {
+      *st = Status::Invalid("out-of-range value for rounding digits");
+      return arg;
+    } else if (!std::isfinite(arg)) {
+      return arg;
+    }
+
+    T scaled_arg = (options.ndigits >= 0) ? (arg * pow10) : (arg / pow10);
+    // Use std::round if scaled value is an integer or not 0.5 when a tie-breaking mode
+    // was set.
+    T result;
+    if (RoundUtil::IsApproxInt(scaled_arg, T(options.abs_tol)) ||
+        (options.round_mode >= RoundMode::HALF_DOWN &&
+         !RoundUtil::IsApproxHalfInt(scaled_arg, T(options.abs_tol)))) {
+      result = std::round(scaled_arg);
+    } else {
+      result = RoundImpl<T, RndMode>::Round(scaled_arg);
+    }
+    result = (options.ndigits >= 0) ? (result / pow10) : (result * pow10);
+    if (!std::isfinite(result)) {
+      *st = Status::Invalid("overflow occurred during rounding");
+      return arg;
+    }
+    // If rounding didn't change value, return original value
+    return RoundUtil::IsApproxEqual(arg, result, T(options.abs_tol)) ? arg : result;
+  }
+};
+
+template <RoundMode RndMode>
+struct MRound {
+  using MRoundState = OptionsWrapper<MRoundOptions>;
+
+  template <typename T, typename Arg>
+  static enable_if_floating_point<Arg, T> Call(KernelContext* ctx, Arg arg, Status* st) {
+    static_assert(std::is_same<T, Arg>::value, "");
+    static_assert(RoundMode::HALF_DOWN > RoundMode::DOWN &&
+                      RoundMode::HALF_DOWN > RoundMode::UP &&
+                      RoundMode::HALF_DOWN > RoundMode::TOWARDS_ZERO &&
+                      RoundMode::HALF_DOWN > RoundMode::TOWARDS_INFINITY &&
+                      RoundMode::HALF_DOWN < RoundMode::HALF_UP &&
+                      RoundMode::HALF_DOWN < RoundMode::HALF_TOWARDS_ZERO &&
+                      RoundMode::HALF_DOWN < RoundMode::HALF_TOWARDS_INFINITY &&
+                      RoundMode::HALF_DOWN < RoundMode::HALF_TO_EVEN &&
+                      RoundMode::HALF_DOWN < RoundMode::HALF_TO_ODD,
+                  "Round modes prefixed with HALF need to be defined last in enum and "
+                  "the first HALF entry has to be HALF_DOWN.");
+    auto options = MRoundState::Get(ctx);
+    auto mult = std::fabs(T(options.multiple));
+    if (mult == 0) {

Review comment:
       My inclination would be to check for `mult > 0` in the kernel init 
function and return an error otherwise. I don't think there's any use case for 
passing zero or a negative number, is there?
   




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to