westonpace commented on code in PR #33775:
URL: https://github.com/apache/arrow/pull/33775#discussion_r1083118332
##########
cpp/src/arrow/compute/api_scalar.h:
##########
@@ -882,6 +891,20 @@ ARROW_EXPORT
Result<Datum> Round(const Datum& arg, RoundOptions options =
RoundOptions::Defaults(),
ExecContext* ctx = NULLPTR);
+/// \brief Round a value to a given precision.
+///
+/// If argument is null the result will be null.
Review Comment:
Which argument? I assume the output is null if either argument is null?
Can we be more explicit?
##########
cpp/src/arrow/compute/kernels/scalar_round_arithmetic_test.cc:
##########
@@ -113,7 +112,7 @@ class TestBaseUnaryRoundArithmetic : public ::testing::Test
{
// (Array, Array)
void AssertUnaryOp(UnaryFunction func, const std::shared_ptr<Array>& arg,
const std::shared_ptr<Array>& expected) {
- ASSERT_OK_AND_ASSIGN(auto actual, func(arg, options_, nullptr));
+ ASSERT_OK_AND_ASSIGN(auto actual, func(arg, options_, nullptr))
Review Comment:
While this does compile and work, we try to add a `;` to the end of these
kinds of macros anyway, for readability:
```suggestion
ASSERT_OK_AND_ASSIGN(auto actual, func(arg, options_, nullptr));
```
##########
cpp/src/arrow/engine/substrait/extension_set.cc:
##########
@@ -790,6 +796,30 @@ ExtensionIdRegistry::SubstraitCallToArrow
DecodeOptionlessUncheckedArithmetic(
};
}
+ExtensionIdRegistry::SubstraitCallToArrow DecodeBinaryRoundingMode(
+ const std::string& function_name) {
+ return [function_name](const SubstraitCall& call) ->
Result<compute::Expression> {
+ ARROW_ASSIGN_OR_RAISE(
+ compute::RoundMode round_mode,
+ ParseOptionOrElse(
+ call, "rounding", kRoundModeParser,
+ {compute::RoundMode::DOWN, compute::RoundMode::UP,
+ compute::RoundMode::TOWARDS_ZERO,
compute::RoundMode::TOWARDS_INFINITY,
+ compute::RoundMode::HALF_DOWN, compute::RoundMode::HALF_UP,
+ compute::RoundMode::HALF_TOWARDS_ZERO,
+ compute::RoundMode::HALF_TOWARDS_INFINITY,
compute::RoundMode::HALF_TO_EVEN,
+ compute::RoundMode::HALF_TO_ODD},
+ compute::RoundMode::HALF_TOWARDS_INFINITY));
+ ARROW_ASSIGN_OR_RAISE(std::vector<compute::Expression> value_args,
+ GetValueArgs(call, 0));
+ std::shared_ptr<compute::RoundBinaryOptions> options =
Review Comment:
Do we want to optimize and call the unary round if the second value is a
scalar? If not in this PR, can we create a follow-up GitHub issue so we don't
lose track of it? Or maybe round_binary itself can fall back to unary rounding
if the second argument is a scalar.
##########
cpp/src/arrow/compute/api_scalar.h:
##########
@@ -882,6 +891,20 @@ ARROW_EXPORT
Result<Datum> Round(const Datum& arg, RoundOptions options =
RoundOptions::Defaults(),
ExecContext* ctx = NULLPTR);
+/// \brief Round a value to a given precision.
+///
+/// If argument is null the result will be null.
+///
+/// \param[in] arg1 the value rounded
Review Comment:
```suggestion
/// \param[in] arg1 the value to be rounded
```
##########
cpp/src/arrow/compute/api_scalar.h:
##########
@@ -882,6 +891,20 @@ ARROW_EXPORT
Result<Datum> Round(const Datum& arg, RoundOptions options =
RoundOptions::Defaults(),
ExecContext* ctx = NULLPTR);
+/// \brief Round a value to a given precision.
+///
+/// If argument is null the result will be null.
+///
+/// \param[in] arg1 the value rounded
+/// \param[in] arg2 the number of significant digits to round to
+/// \param[in] options rounding options (rounding mode and number of digits),
optional
Review Comment:
```suggestion
/// \param[in] options rounding options (rounding mode), optional
```
Or just get rid of the parentheses section entirely.
##########
cpp/src/arrow/compute/api_scalar.h:
##########
@@ -882,6 +891,20 @@ ARROW_EXPORT
Result<Datum> Round(const Datum& arg, RoundOptions options =
RoundOptions::Defaults(),
ExecContext* ctx = NULLPTR);
+/// \brief Round a value to a given precision.
+///
+/// If argument is null the result will be null.
+///
+/// \param[in] arg1 the value rounded
+/// \param[in] arg2 the number of significant digits to round to
Review Comment:
Can this be negative? Do we define elsewhere what that entails?
##########
cpp/src/arrow/compute/kernels/scalar_round_arithmetic_test.cc:
##########
@@ -18,7 +18,6 @@
#include <algorithm>
#include <cmath>
#include <memory>
-#include <string>
Review Comment:
Our guideline for includes is [`iwyu`](https://include-what-you-use.org/).
We don't always follow it perfectly (the conformance tool doesn't like type_fwd
files) but it is what we aim for. Please don't remove includes if they are
used in the file (I still see many instances of `std::string`) even if the file
compiles otherwise (transitive includes are potentially unstable).
##########
cpp/src/arrow/compute/kernels/scalar_round.cc:
##########
@@ -751,60 +877,25 @@ ArrayKernelExec
GenerateArithmeticWithFixedIntOutType(detail::GetTypeId get_id)
}
}
-struct ArithmeticFunction : ScalarFunction {
+struct RoundFunction : ScalarFunction {
using ScalarFunction::ScalarFunction;
Result<const Kernel*> DispatchBest(std::vector<TypeHolder>* types) const
override {
RETURN_NOT_OK(CheckArity(types->size()));
- RETURN_NOT_OK(CheckDecimals(types));
-
using arrow::compute::detail::DispatchExactImpl;
if (auto kernel = DispatchExactImpl(this, *types)) return kernel;
EnsureDictionaryDecoded(types);
- // Only promote types for binary functions
- if (types->size() == 2) {
- ReplaceNullWithOtherType(types);
- TimeUnit::type finest_unit;
- if (CommonTemporalResolution(types->data(), types->size(),
&finest_unit)) {
- ReplaceTemporalTypes(finest_unit, types);
- } else {
- if (TypeHolder type = CommonNumeric(*types)) {
- ReplaceTypes(type, types);
- }
- }
- }
-
if (auto kernel = DispatchExactImpl(this, *types)) return kernel;
return arrow::compute::detail::NoMatchingKernel(this, *types);
}
-
- Status CheckDecimals(std::vector<TypeHolder>* types) const {
- if (!HasDecimal(*types)) return Status::OK();
-
- if (types->size() == 2) {
- // "add_checked" -> "add"
- const auto func_name = name();
- const std::string op = func_name.substr(0, func_name.find("_"));
- if (op == "add" || op == "subtract") {
- return CastBinaryDecimalArgs(DecimalPromotion::kAdd, types);
- } else if (op == "multiply") {
- return CastBinaryDecimalArgs(DecimalPromotion::kMultiply, types);
- } else if (op == "divide") {
- return CastBinaryDecimalArgs(DecimalPromotion::kDivide, types);
- } else {
- return Status::Invalid("Invalid decimal function: ", func_name);
- }
- }
- return Status::OK();
- }
};
-/// An ArithmeticFunction that promotes only decimal arguments to double.
-struct ArithmeticDecimalToFloatingPointFunction : public ArithmeticFunction {
- using ArithmeticFunction::ArithmeticFunction;
+/// An RoundFunction that promotes only decimal arguments to double.
Review Comment:
```suggestion
/// A RoundFunction that promotes only decimal arguments to double.
```
##########
cpp/src/arrow/compute/kernels/scalar_round.cc:
##########
@@ -452,6 +468,127 @@ struct Round<ArrowType, kRoundMode,
enable_if_decimal<ArrowType>> {
}
};
+template <typename ArrowType, RoundMode RndMode, typename Enable = void>
+struct RoundBinary {
+ using CType = typename TypeTraits<ArrowType>::CType;
+ using State = RoundOptionsWrapper<RoundBinaryOptions>;
+
+ explicit RoundBinary(const State& state, const DataType& out_ty) {}
+
+ template <typename T = ArrowType, typename CType0 = typename
TypeTraits<T>::CType0,
+ typename CType1 = typename TypeTraits<T>::CType1>
+ enable_if_floating_value<CType> Call(KernelContext* ctx, CType0 arg0, CType1
arg1,
+ Status* st) const {
+ // Do not process Inf or NaN because they will trigger the overflow error
at end of
+ // function.
Review Comment:
Do you have any tests with infinite or NaN values?
##########
cpp/src/arrow/compute/kernels/scalar_round_arithmetic_test.cc:
##########
@@ -965,6 +1086,97 @@ TYPED_TEST(TestUnaryRoundFloating, Round) {
}
}
+TYPED_TEST_SUITE(TestBinaryRoundIntegral, IntegralTypes);
+TYPED_TEST_SUITE(TestBinaryRoundSigned, SignedIntegerTypes);
+TYPED_TEST_SUITE(TestBinaryRoundUnsigned, UnsignedIntegerTypes);
+TYPED_TEST_SUITE(TestBinaryRoundFloating, FloatingTypes);
+
+TYPED_TEST(TestBinaryRoundSigned, Round) {
+ // Test different rounding modes for integer rounding
+ std::string values("[0, 1, -13, -50, 115]");
+ for (const auto& round_mode : kRoundModes) {
+ this->SetRoundMode(round_mode);
+ this->AssertBinaryOp(RoundBinary, values, 0, ArrayFromJSON(float64(),
values));
+ }
+
+ // Test different round N-digits for nearest rounding mode
+ std::vector<std::pair<int32_t, std::string>> ndigits_and_expected{{
+ {-2, "[0.0, 0.0, -0.0, -100, 100]"},
+ {-1, "[0.0, 0.0, -10, -50, 120]"},
+ {0, values},
+ {1, values},
+ {2, values},
+ }};
+ this->SetRoundMode(RoundMode::HALF_TOWARDS_INFINITY);
+ for (const auto& pair : ndigits_and_expected) {
+ this->AssertBinaryOp(RoundBinary, values, pair.first,
+ ArrayFromJSON(float64(), pair.second));
+ }
+}
+
+TYPED_TEST(TestBinaryRoundUnsigned, Round) {
+ // Test different rounding modes for integer rounding
+ std::string values("[0, 1, 13, 50, 115]");
+ for (const auto& round_mode : kRoundModes) {
+ this->SetRoundMode(round_mode);
+ this->AssertBinaryOp(RoundBinary, values, 0, ArrayFromJSON(float64(),
values));
+ }
+
+ // Test different round N-digits for nearest rounding mode
+ std::vector<std::pair<int32_t, std::string>> ndigits_and_expected{{
+ {-2, "[0, 0, 0, 100, 100]"},
+ {-1, "[0, 0, 10, 50, 120]"},
+ {0, values},
+ {1, values},
+ {2, values},
+ }};
+ this->SetRoundMode(RoundMode::HALF_TOWARDS_INFINITY);
+ for (const auto& pair : ndigits_and_expected) {
+ this->AssertBinaryOp(RoundBinary, values, pair.first,
+ ArrayFromJSON(float64(), pair.second));
+ }
+}
+
+TYPED_TEST(TestBinaryRoundFloating, Round) {
+ this->SetNansEqual(true);
+
+ // Test different rounding modes
+ std::string values("[3.2, 3.5, 3.7, 4.5, -3.2, -3.5, -3.7]");
+ std::vector<std::pair<RoundMode, std::string>> rmode_and_expected{{
+ {RoundMode::DOWN, "[3, 3, 3, 4, -4, -4, -4]"},
+ {RoundMode::UP, "[4, 4, 4, 5, -3, -3, -3]"},
+ {RoundMode::TOWARDS_ZERO, "[3, 3, 3, 4, -3, -3, -3]"},
+ {RoundMode::TOWARDS_INFINITY, "[4, 4, 4, 5, -4, -4, -4]"},
+ {RoundMode::HALF_DOWN, "[3, 3, 4, 4, -3, -4, -4]"},
+ {RoundMode::HALF_UP, "[3, 4, 4, 5, -3, -3, -4]"},
+ {RoundMode::HALF_TOWARDS_ZERO, "[3, 3, 4, 4, -3, -3, -4]"},
+ {RoundMode::HALF_TOWARDS_INFINITY, "[3, 4, 4, 5, -3, -4, -4]"},
+ {RoundMode::HALF_TO_EVEN, "[3, 4, 4, 4, -3, -4, -4]"},
+ {RoundMode::HALF_TO_ODD, "[3, 3, 4, 5, -3, -3, -4]"},
+ }};
+ for (const auto& pair : rmode_and_expected) {
+ this->SetRoundMode(pair.first);
+ this->AssertBinaryOp(RoundBinary, "[]", "[]", "[]");
+ this->AssertBinaryOp(RoundBinary, "[null, 0, Inf, -Inf, NaN, -NaN]",
+ "[0, 0, 0, 0, 0, 0]", "[null, 0, Inf, -Inf, NaN,
-NaN]");
+ this->AssertBinaryOp(RoundBinary, values, 0, pair.second);
+ }
+
+ // Test different round N-digits for nearest rounding mode
+ values = "[320, 3.5, 3.075, 4.5, -3.212, -35.1234, -3.045]";
+ std::vector<std::pair<int32_t, std::string>> ndigits_and_expected{{
+ {-2, "[300, 0.0, 0.0, 0.0, -0.0, -0.0, -0.0]"},
+ {-1, "[320, 0.0, 0.0, 0.0, -0.0, -40, -0.0]"},
+ {0, "[320, 4, 3, 5, -3, -35, -3]"},
+ {1, "[320, 3.5, 3.1, 4.5, -3.2, -35.1, -3]"},
+ {2, "[320, 3.5, 3.08, 4.5, -3.21, -35.12, -3.05]"},
+ }};
+ this->SetRoundMode(RoundMode::HALF_TOWARDS_INFINITY);
+ for (const auto& pair : ndigits_and_expected) {
+ this->AssertBinaryOp(RoundBinary, values, pair.first, pair.second);
+ }
+}
Review Comment:
Can you add some unit tests that consider nulls in both the values and the
num_digits arguments? Also, maybe a few tests with scalars (esp. using a scalar
for num_digits and an array for values, which should be equivalent to unary
rounding).
##########
cpp/src/arrow/engine/substrait/extension_set.cc:
##########
@@ -790,6 +796,30 @@ ExtensionIdRegistry::SubstraitCallToArrow
DecodeOptionlessUncheckedArithmetic(
};
}
+ExtensionIdRegistry::SubstraitCallToArrow DecodeBinaryRoundingMode(
+ const std::string& function_name) {
+ return [function_name](const SubstraitCall& call) ->
Result<compute::Expression> {
+ ARROW_ASSIGN_OR_RAISE(
+ compute::RoundMode round_mode,
+ ParseOptionOrElse(
+ call, "rounding", kRoundModeParser,
+ {compute::RoundMode::DOWN, compute::RoundMode::UP,
+ compute::RoundMode::TOWARDS_ZERO,
compute::RoundMode::TOWARDS_INFINITY,
+ compute::RoundMode::HALF_DOWN, compute::RoundMode::HALF_UP,
+ compute::RoundMode::HALF_TOWARDS_ZERO,
+ compute::RoundMode::HALF_TOWARDS_INFINITY,
compute::RoundMode::HALF_TO_EVEN,
+ compute::RoundMode::HALF_TO_ODD},
+ compute::RoundMode::HALF_TOWARDS_INFINITY));
Review Comment:
It appears you are defaulting to `HALF_TOWARDS_INFINITY`, but shouldn't the
default be `HALF_TO_EVEN`?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]