lidavidm commented on a change in pull request #10349: URL: https://github.com/apache/arrow/pull/10349#discussion_r674997744
########## File path: cpp/src/arrow/compute/kernels/scalar_arithmetic.cc ########## @@ -817,24 +818,158 @@ struct Log1pChecked { } }; +struct RoundUtil { + template <typename T, enable_if_t<std::is_floating_point<T>::value, bool> = true> + static constexpr bool ApproxEqual(const T x, const T y) { + return (x == y) || (std::fabs(x - y) <= std::numeric_limits<T>::epsilon()); + } + + template <typename T, enable_if_t<std::is_floating_point<T>::value, bool> = true> + static constexpr bool IsHalf(T val) { + // |frac| == 0.5? + return ApproxEqual(std::fmod(std::fabs(val), T(1)), T(0.5)); + } + + template <typename T> + static enable_if_floating_point<T> Pow10(const int power) { + static constexpr auto pow10 = std::array<T, 39>{ + 1e-19, 1e-18, 1e-17, 1e-16, 1e-15, 1e-14, 1e-13, 1e-12, 1e-11, 1e-10, + 1e-9, 1e-8, 1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, + 1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8, 1e9, 1e10, + 1e11, 1e12, 1e13, 1e14, 1e15, 1e16, 1e17, 1e18, 1e19}; + return pow10.at(power + 19); + } +}; + +// Specializations of rounding implementations for kernels +template <typename T, RoundMode RndMode> +struct RoundImpl { + static constexpr enable_if_floating_point<T> Round(T) { return T(0); } +}; + +template <typename T> +struct RoundImpl<T, RoundMode::TOWARDS_NEG_INFINITY> { + static constexpr enable_if_floating_point<T> Round(T val) { return std::floor(val); } +}; + +template <typename T> +struct RoundImpl<T, RoundMode::TOWARDS_POS_INFINITY> { + static constexpr enable_if_floating_point<T> Round(T val) { return std::ceil(val); } +}; + +template <typename T> +struct RoundImpl<T, RoundMode::TOWARDS_ZERO> { + static constexpr enable_if_floating_point<T> Round(T val) { return std::trunc(val); } +}; + +template <typename T> +struct RoundImpl<T, RoundMode::TOWARDS_INFINITY> { + static constexpr enable_if_floating_point<T> Round(T val) { + return std::signbit(val) ? 
std::floor(val) : std::ceil(val); + } +}; + +template <typename T> +struct RoundImpl<T, RoundMode::HALF_TOWARDS_NEG_INFINITY> { + static constexpr enable_if_floating_point<T> Round(T val) { + return std::ceil(val - T(0.5)); + } +}; + +template <typename T> +struct RoundImpl<T, RoundMode::HALF_TOWARDS_POS_INFINITY> { + static constexpr enable_if_floating_point<T> Round(T val) { + return std::floor(val + T(0.5)); + } +}; + +template <typename T> +struct RoundImpl<T, RoundMode::HALF_TOWARDS_ZERO> { + static constexpr enable_if_floating_point<T> Round(T val) { + return std::copysign(std::ceil(std::fabs(val) - T(0.5)), val); + } +}; + +template <typename T> +struct RoundImpl<T, RoundMode::HALF_TOWARDS_INFINITY> { + static enable_if_floating_point<T> Round(T val) { + return std::copysign(std::floor(std::fabs(val) + T(0.5)), val); + } +}; + +template <typename T> +struct RoundImpl<T, RoundMode::HALF_TO_EVEN> { + static enable_if_floating_point<T> Round(T val) { + if (!RoundUtil::IsHalf(val)) { + return std::round(val); + } + auto floor = std::floor(val); + // Odd + 1, Even + 0 + return floor + T(std::fmod(std::fabs(floor), T(2)) >= T(1)); + } +}; + +template <typename T> +struct RoundImpl<T, RoundMode::HALF_TO_ODD> { + static enable_if_floating_point<T> Round(T val) { + if (!RoundUtil::IsHalf(val)) { + return std::round(val); + } + auto floor = std::floor(val); + // Odd + 0, Even + 1 + return floor + T(std::fmod(std::fabs(floor), T(2)) < T(1)); + } +}; + +template <RoundMode RndMode> +struct MRound { + template <typename T, typename Arg> + static enable_if_floating_point<Arg, T> Call(KernelContext* ctx, Arg arg, Status*) { + static_assert(std::is_same<T, Arg>::value, ""); + if (std::isnan(arg)) { + return arg; + } + auto options = OptionsWrapper<MRoundOptions>::Get(ctx); + const auto mult = std::fabs(T(options.multiple)); + return (mult == T(0)) ? 
T(0) : (RoundImpl<T, RndMode>::Round(arg / mult) * mult); + } +}; + +template <RoundMode RndMode> +struct Round { + template <typename T, typename Arg> + static enable_if_floating_point<Arg, T> Call(KernelContext* ctx, Arg arg, Status*) { + static_assert(std::is_same<T, Arg>::value, ""); + auto options = OptionsWrapper<RoundOptions>::Get(ctx); + const auto mult = RoundUtil::Pow10<T>(-options.ndigits); Review comment: I would suggest either providing a checked and an unchecked variant, or else raising an error (since really, the problem is that the option value is invalid). I don't think it is very useful to round to 0 or to not round at all. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org