This is an automated email from the ASF dual-hosted git repository.
apitrou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new d18107245e GH-38833: [C++] Avoid hash_mean overflow (#39349)
d18107245e is described below
commit d18107245e8e82a8d7ec40e0ae27f083ffbb7cc4
Author: Jin Shang <[email protected]>
AuthorDate: Thu Jan 11 23:40:49 2024 +0800
GH-38833: [C++] Avoid hash_mean overflow (#39349)
### Rationale for this change
hash_mean overflows if the sum of a group is larger than uint64 max.
### What changes are included in this PR?
Save the intermediate sum as double to avoid overflow
### Are these changes tested?
yes
### Are there any user-facing changes?
no
* Closes: #38833
Authored-by: Jin Shang <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
---
cpp/src/arrow/acero/hash_aggregate_test.cc | 36 +++++++++++++++++++++++++
cpp/src/arrow/compute/kernels/hash_aggregate.cc | 24 ++++++++++++-----
2 files changed, 54 insertions(+), 6 deletions(-)
diff --git a/cpp/src/arrow/acero/hash_aggregate_test.cc b/cpp/src/arrow/acero/hash_aggregate_test.cc
index a4874f3581..2626fd5037 100644
--- a/cpp/src/arrow/acero/hash_aggregate_test.cc
+++ b/cpp/src/arrow/acero/hash_aggregate_test.cc
@@ -1694,6 +1694,42 @@ TEST_P(GroupBy, SumMeanProductScalar) {
}
}
+TEST_P(GroupBy, MeanOverflow) {
+ BatchesWithSchema input;
+ // would overflow if intermediate sum is integer
+ input.batches = {
+      ExecBatchFromJSON({int64(), int64()}, {ArgShape::SCALAR, ArgShape::ARRAY},
+
+                        "[[9223372036854775805, 1], [9223372036854775805, 1], "
+                        "[9223372036854775805, 2], [9223372036854775805, 3]]"),
+      ExecBatchFromJSON({int64(), int64()}, {ArgShape::SCALAR, ArgShape::ARRAY},
+                        "[[null, 1], [null, 1], [null, 2], [null, 3]]"),
+ ExecBatchFromJSON({int64(), int64()},
+ "[[9223372036854775805, 1], [9223372036854775805, 2], "
+ "[9223372036854775805, 3]]"),
+ };
+ input.schema = schema({field("argument", int64()), field("key", int64())});
+ for (bool use_threads : {true, false}) {
+ SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+ ASSERT_OK_AND_ASSIGN(Datum actual,
+ RunGroupBy(input, {"key"},
+ {
+                                 {"hash_mean", nullptr, "argument", "hash_mean"},
+ },
+ use_threads));
+ Datum expected = ArrayFromJSON(struct_({
+ field("key", int64()),
+ field("hash_mean", float64()),
+ }),
+ R"([
+ [1, 9223372036854775805],
+ [2, 9223372036854775805],
+ [3, 9223372036854775805]
+ ])");
+ AssertDatumsApproxEqual(expected, actual, /*verbose=*/true);
+ }
+}
+
TEST_P(GroupBy, VarianceAndStddev) {
auto batch = RecordBatchFromJSON(
schema({field("argument", int32()), field("key", int64())}), R"([
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
index c37e45513d..5052d8dd66 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
@@ -38,6 +38,7 @@
#include "arrow/compute/row/grouper.h"
#include "arrow/record_batch.h"
#include "arrow/stl_allocator.h"
+#include "arrow/type_traits.h"
#include "arrow/util/bit_run_reader.h"
#include "arrow/util/bitmap_ops.h"
#include "arrow/util/bitmap_writer.h"
@@ -441,9 +442,10 @@ struct GroupedCountImpl : public GroupedAggregator {
// ----------------------------------------------------------------------
// Sum/Mean/Product implementation
-template <typename Type, typename Impl>
+template <typename Type, typename Impl,
+ typename AccumulateType = typename FindAccumulatorType<Type>::Type>
struct GroupedReducingAggregator : public GroupedAggregator {
- using AccType = typename FindAccumulatorType<Type>::Type;
+ using AccType = AccumulateType;
using CType = typename TypeTraits<AccType>::CType;
using InputCType = typename TypeTraits<Type>::CType;
@@ -483,7 +485,8 @@ struct GroupedReducingAggregator : public GroupedAggregator {
Status Merge(GroupedAggregator&& raw_other,
const ArrayData& group_id_mapping) override {
-    auto other = checked_cast<GroupedReducingAggregator<Type, Impl>*>(&raw_other);
+    auto other =
+        checked_cast<GroupedReducingAggregator<Type, Impl, AccType>*>(&raw_other);
CType* reduced = reduced_.mutable_data();
int64_t* counts = counts_.mutable_data();
@@ -733,9 +736,18 @@ using GroupedProductFactory =
// ----------------------------------------------------------------------
// Mean implementation
+template <typename T>
+struct GroupedMeanAccType {
+  using Type = typename std::conditional<is_number_type<T>::value, DoubleType,
+                                         typename FindAccumulatorType<T>::Type>::type;
+};
+
template <typename Type>
-struct GroupedMeanImpl : public GroupedReducingAggregator<Type, GroupedMeanImpl<Type>> {
-  using Base = GroupedReducingAggregator<Type, GroupedMeanImpl<Type>>;
+struct GroupedMeanImpl
+    : public GroupedReducingAggregator<Type, GroupedMeanImpl<Type>,
+                                       typename GroupedMeanAccType<Type>::Type> {
+  using Base = GroupedReducingAggregator<Type, GroupedMeanImpl<Type>,
+                                         typename GroupedMeanAccType<Type>::Type>;
using CType = typename Base::CType;
using InputCType = typename Base::InputCType;
using MeanType =
@@ -746,7 +758,7 @@ struct GroupedMeanImpl : public GroupedReducingAggregator<Type, GroupedMeanImpl<
template <typename T = Type>
static enable_if_number<T, CType> Reduce(const DataType&, const CType u,
const InputCType v) {
-    return static_cast<CType>(to_unsigned(u) + to_unsigned(static_cast<CType>(v)));
+ return static_cast<CType>(u) + static_cast<CType>(v);
}
static CType Reduce(const DataType&, const CType u, const CType v) {