edponce commented on a change in pull request #10349:
URL: https://github.com/apache/arrow/pull/10349#discussion_r703994513



##########
File path: cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
##########
@@ -852,24 +854,232 @@ struct LogbChecked {
   }
 };
 
+struct RoundUtil {
+  template <typename T>
+  static constexpr enable_if_t<std::is_floating_point<T>::value, bool> 
IsApproxEqual(
+      const T a, const T b, const T abs_tol = kDefaultAbsoluteTolerance) {
+    return std::fabs(a - b) <= abs_tol;
+  }
+
+  template <typename T>
+  static constexpr enable_if_t<std::is_floating_point<T>::value, bool> 
IsApproxInt(
+      const T val, const T abs_tol = kDefaultAbsoluteTolerance) {
+    // |frac| ~ 0.0?
+    return IsApproxEqual(val, std::round(val), abs_tol);
+  }
+
+  template <typename T>
+  static constexpr enable_if_t<std::is_floating_point<T>::value, bool> 
IsApproxHalfInt(
+      const T val, const T abs_tol = kDefaultAbsoluteTolerance) {
+    // |frac| ~ 0.5?
+    return IsApproxEqual(val - std::floor(val), T(0.5), abs_tol);
+  }
+
+  template <typename T>
+  static enable_if_floating_point<T> Pow10(const int64_t power) {
+    const T lut[]{1e0F,  1e1F,  1e2F,  1e3F,  1e4F,  1e5F,  1e6F,  1e7F,
+                  1e8F,  1e9F,  1e10F, 1e11F, 1e12F, 1e13F, 1e14F, 1e15F,
+                  1e16F, 1e17F, 1e18F, 1e19F, 1e20F, 1e21F, 1e22F};
+    // Return NaN if index is out-of-range.
+    auto lut_size = (int64_t)(sizeof(lut) / sizeof(*lut));
+    return (power >= 0 && power < lut_size) ? lut[power] : std::nanf("");
+  }
+};
+
+// Specializations of rounding implementations for round kernels
+template <typename T, RoundMode>
+struct RoundImpl {};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::DOWN> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return std::floor(val);
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::UP> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return std::ceil(val);
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::TOWARDS_ZERO> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return std::signbit(val) ? std::ceil(val) : std::floor(val);
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::TOWARDS_INFINITY> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return std::signbit(val) ? std::floor(val) : std::ceil(val);
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::HALF_DOWN> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return RoundImpl<T, RoundMode::DOWN>::Round(val);
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::HALF_UP> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return RoundImpl<T, RoundMode::UP>::Round(val);
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::HALF_TOWARDS_ZERO> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return RoundImpl<T, RoundMode::TOWARDS_ZERO>::Round(val);
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::HALF_TOWARDS_INFINITY> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return RoundImpl<T, RoundMode::TOWARDS_INFINITY>::Round(val);
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::HALF_TO_EVEN> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return std::round(val * T(0.5)) * 2;
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::HALF_TO_ODD> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return std::floor(val * T(0.5)) + std::ceil(val * T(0.5));
+  }
+};
+
+template <RoundMode RndMode>
+struct Round {
+  using RoundState = OptionsWrapper<RoundOptions>;
+
+  template <typename T, typename Arg>
+  static enable_if_floating_point<Arg, T> Call(KernelContext* ctx, Arg arg, 
Status* st) {
+    static_assert(std::is_same<T, Arg>::value, "");
+    static_assert(RoundMode::HALF_DOWN > RoundMode::DOWN &&
+                      RoundMode::HALF_DOWN > RoundMode::UP &&
+                      RoundMode::HALF_DOWN > RoundMode::TOWARDS_ZERO &&
+                      RoundMode::HALF_DOWN > RoundMode::TOWARDS_INFINITY &&
+                      RoundMode::HALF_DOWN < RoundMode::HALF_UP &&
+                      RoundMode::HALF_DOWN < RoundMode::HALF_TOWARDS_ZERO &&
+                      RoundMode::HALF_DOWN < RoundMode::HALF_TOWARDS_INFINITY 
&&
+                      RoundMode::HALF_DOWN < RoundMode::HALF_TO_EVEN &&
+                      RoundMode::HALF_DOWN < RoundMode::HALF_TO_ODD,
+                  "Round modes prefixed with HALF need to be defined last in 
enum and "
+                  "the first HALF entry has to be HALF_DOWN.");
+
+    auto options = RoundState::Get(ctx);
+    auto pow10 = RoundUtil::Pow10<T>(std::llabs(options.ndigits));
+    if (std::isnan(pow10)) {
+      *st = Status::Invalid("out-of-range value for rounding digits");
+      return arg;
+    } else if (!std::isfinite(arg)) {
+      return arg;
+    }
+
+    T scaled_arg = (options.ndigits >= 0) ? (arg * pow10) : (arg / pow10);
+    // Use std::round if scaled value is an integer or not 0.5 when a 
tie-breaking mode
+    // was set.
+    T result;
+    if (RoundUtil::IsApproxInt(scaled_arg, T(options.abs_tol)) ||
+        (options.round_mode >= RoundMode::HALF_DOWN &&
+         !RoundUtil::IsApproxHalfInt(scaled_arg, T(options.abs_tol)))) {
+      result = std::round(scaled_arg);
+    } else {
+      result = RoundImpl<T, RndMode>::Round(scaled_arg);
+    }
+    result = (options.ndigits >= 0) ? (result / pow10) : (result * pow10);
+    if (!std::isfinite(result)) {
+      *st = Status::Invalid("overflow occurred during rounding");
+      return arg;
+    }
+    // If rounding didn't change value, return original value
+    return RoundUtil::IsApproxEqual(arg, result, T(options.abs_tol)) ? arg : 
result;
+  }
+};
+
+template <RoundMode RndMode>
+struct MRound {
+  using MRoundState = OptionsWrapper<MRoundOptions>;
+
+  template <typename T, typename Arg>
+  static enable_if_floating_point<Arg, T> Call(KernelContext* ctx, Arg arg, 
Status* st) {
+    static_assert(std::is_same<T, Arg>::value, "");
+    static_assert(RoundMode::HALF_DOWN > RoundMode::DOWN &&
+                      RoundMode::HALF_DOWN > RoundMode::UP &&
+                      RoundMode::HALF_DOWN > RoundMode::TOWARDS_ZERO &&
+                      RoundMode::HALF_DOWN > RoundMode::TOWARDS_INFINITY &&
+                      RoundMode::HALF_DOWN < RoundMode::HALF_UP &&
+                      RoundMode::HALF_DOWN < RoundMode::HALF_TOWARDS_ZERO &&
+                      RoundMode::HALF_DOWN < RoundMode::HALF_TOWARDS_INFINITY 
&&
+                      RoundMode::HALF_DOWN < RoundMode::HALF_TO_EVEN &&
+                      RoundMode::HALF_DOWN < RoundMode::HALF_TO_ODD,
+                  "Round modes prefixed with HALF need to be defined last in 
enum and "
+                  "the first HALF entry has to be HALF_DOWN.");
+    auto options = MRoundState::Get(ctx);
+    auto mult = std::fabs(T(options.multiple));
+    if (mult == 0) {

Review comment:
       Well, the only value not allowed is `mult = 0`. We could require `mult` to
be non-zero and remove the check, but I think we should leave it as-is for now.
Ideally, a kernel could have multiple `CallXXX` variants, selected based on the
options. In this case, if `mult` is zero, it would invoke
   ```c++
   static ... CallZero(...) {
         return std::isfinite(arg) ? 0 : std::nanf("");
   }
   ```
   and all other cases would invoke the current `Call()`. This is not the only
compute function with special cases — in fact, most of them have some — and
having this capability would improve performance and also make the code more
amenable to SIMD. I am working on a refactoring to support these ideas.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to