edponce commented on a change in pull request #10349:
URL: https://github.com/apache/arrow/pull/10349#discussion_r705496648



##########
File path: cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
##########
@@ -852,24 +853,243 @@ struct LogbChecked {
   }
 };
 
/// Floating-point helpers shared by the round kernels.
struct RoundUtil {
  /// Return true if `val` has no fractional part, i.e. floor(val) == val.
  ///
  /// +/-Inf compare equal to their floor and therefore report true; NaN reports
  /// false. Not constexpr: std::floor is not a constexpr function in C++17, so a
  /// constexpr template whose body always calls it is ill-formed (NDR).
  template <typename T>
  static std::enable_if_t<std::is_floating_point<T>::value, bool> IsInteger(
      const T val) {
    return std::floor(val) == val;
  }

  /// Return true if val - floor(val) is exactly 0.5.
  ///
  /// The fraction is measured towards -Inf, so both 2.5 and -2.5 qualify.
  template <typename T>
  static std::enable_if_t<std::is_floating_point<T>::value, bool> IsHalfInteger(
      const T val) {
    return (val - std::floor(val)) == T(0.5);
  }

  /// Compute 10^power for an arbitrary (possibly negative) integer exponent.
  ///
  /// Table entries are written as double literals. Every power of ten up to
  /// 1e22 is exact in a double, whereas float literals from 1e11F upwards are
  /// not, which made Pow10<double>(11) and beyond inexact. For T = float, each
  /// entry is rounded once from the exact double value, which is the same as
  /// the float literal.
  ///
  /// Exponents past the table are reached by repeated multiplication. The loop
  /// stops as soon as the result overflows to +Inf, so huge exponents do not
  /// spin for ~|power| iterations.
  ///
  /// \return 10^power for power >= 0, otherwise 1 / 10^|power| (which is 0 once
  ///         10^|power| overflows).
  template <typename T = double>
  static std::enable_if_t<std::is_floating_point<T>::value, T> Pow10(int64_t power) {
    static constexpr T lut[] = {1e0,  1e1,  1e2,  1e3,  1e4,  1e5,  1e6,  1e7,
                                1e8,  1e9,  1e10, 1e11, 1e12, 1e13, 1e14, 1e15,
                                1e16, 1e17, 1e18, 1e19, 1e20, 1e21, 1e22};
    constexpr uint64_t lut_size = sizeof(lut) / sizeof(lut[0]);
    // Take |power| in unsigned arithmetic: std::abs(INT64_MIN) is undefined.
    const uint64_t abs_power = (power < 0)
                                   ? (uint64_t{0} - static_cast<uint64_t>(power))
                                   : static_cast<uint64_t>(power);
    T pow10;
    if (abs_power < lut_size) {
      pow10 = lut[abs_power];
    } else {
      pow10 = lut[lut_size - 1];
      for (uint64_t exp = lut_size - 1; exp < abs_power && std::isfinite(pow10);
           ++exp) {
        pow10 *= T(10);
      }
    }
    return (power >= 0) ? pow10 : (T(1) / pow10);
  }
};
+
+// Specializations of rounding implementations for round kernels
+template <typename, RoundMode>
+struct RoundImpl;
+
+template <typename T>
+struct RoundImpl<T, RoundMode::DOWN> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return std::floor(val);
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::UP> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return std::ceil(val);
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::TOWARDS_ZERO> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return std::trunc(val);
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::TOWARDS_INFINITY> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return std::signbit(val) ? std::floor(val) : std::ceil(val);
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::HALF_DOWN> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return RoundImpl<T, RoundMode::DOWN>::Round(val);
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::HALF_UP> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return RoundImpl<T, RoundMode::UP>::Round(val);
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::HALF_TOWARDS_ZERO> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return RoundImpl<T, RoundMode::TOWARDS_ZERO>::Round(val);
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::HALF_TOWARDS_INFINITY> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return RoundImpl<T, RoundMode::TOWARDS_INFINITY>::Round(val);
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::HALF_TO_EVEN> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return std::round(val * T(0.5)) * 2;
+  }
+};
+
+template <typename T>
+struct RoundImpl<T, RoundMode::HALF_TO_ODD> {
+  static constexpr enable_if_floating_point<T> Round(const T val) {
+    return std::floor(val * T(0.5)) + std::ceil(val * T(0.5));
+  }
+};
+
+// Specializations of kernel state for round kernels
+template <typename>
+struct RoundOptionsWrapper;
+
+template <>
+struct RoundOptionsWrapper<RoundOptions> : public OptionsWrapper<RoundOptions> 
{
+  using OptionsType = RoundOptions;
+  double pow10;
+
+  explicit RoundOptionsWrapper(OptionsType options) : 
OptionsWrapper(std::move(options)) {
+    // Only positive of powers of 10 are used because combining multiply and
+    // division operations produced more stable rounding than using 
multiply-only.
+    // Refer to NumPy's round implementation:
+    // 
https://github.com/numpy/numpy/blob/7b2f20b406d27364c812f7a81a9c901afbd3600c/numpy/core/src/multiarray/calculation.c#L589
+    pow10 = RoundUtil::Pow10(std::abs(options.ndigits));
+  }
+
+  static Result<std::unique_ptr<KernelState>> Init(KernelContext* ctx,
+                                                   const KernelInitArgs& args) 
{
+    if (auto options = static_cast<const OptionsType*>(args.options)) {
+      return 
arrow::internal::make_unique<RoundOptionsWrapper<OptionsType>>(*options);
+    }
+
+    return Status::Invalid(
+        "Attempted to initialize KernelState from null FunctionOptions");
+  }
+};
+
+template <>
+struct RoundOptionsWrapper<RoundToMultipleOptions>
+    : public OptionsWrapper<RoundToMultipleOptions> {
+  using OptionsType = RoundToMultipleOptions;
+
+  static Result<std::unique_ptr<KernelState>> Init(KernelContext* ctx,
+                                                   const KernelInitArgs& args) 
{
+    ARROW_ASSIGN_OR_RAISE(auto state, OptionsWrapper<OptionsType>::Init(ctx, 
args));
+    auto options = Get(*state);
+    if (options.multiple <= 0) {
+      return Status::Invalid("Rounding multiple has to be a non-zero positive 
value");
+    }
+    return std::move(state);
+  }
+};
+
+template <RoundMode RndMode>
+struct Round {
+  using State = RoundOptionsWrapper<RoundOptions>;
+
+  template <typename T, typename Arg>
+  static enable_if_floating_point<Arg, T> Call(KernelContext* ctx, Arg arg, 
Status* st) {
+    static_assert(std::is_same<T, Arg>::value, "");
+    // Do not process Inf or NaN because they will trigger the overflow error 
at end of
+    // function.
+    if (!std::isfinite(arg)) {
+      return arg;
+    }
+    auto state = static_cast<State*>(ctx->state());
+    auto options = state->options;
+    auto pow10 = T(state->pow10);
+    auto round_val = (options.ndigits >= 0) ? (arg * pow10) : (arg / pow10);
+    // Use std::round() if in tie-breaking mode and scaled value is not 0.5.
+    if ((options.round_mode >= RoundMode::HALF_DOWN) &&

Review comment:
       Good catch! Template parameters of integer-backed types (e.g., the `RoundMode` enum) can be used in 
"runtime" conditionals. I had done this somewhere else, but it slipped my mind here.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to