This is an automated email from the ASF dual-hosted git repository.

apitrou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new d18107245e GH-38833: [C++] Avoid hash_mean overflow (#39349)
d18107245e is described below

commit d18107245e8e82a8d7ec40e0ae27f083ffbb7cc4
Author: Jin Shang <[email protected]>
AuthorDate: Thu Jan 11 23:40:49 2024 +0800

    GH-38833: [C++] Avoid hash_mean overflow (#39349)
    
    
    
    ### Rationale for this change
    
    hash_mean overflows if the sum of a group exceeds the range of its integer accumulator (int64 for signed integer inputs, uint64 for unsigned ones).
    
    ### What changes are included in this PR?
    
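    Accumulate the intermediate per-group sum as a double for numeric input types, avoiding integer overflow (at the cost of exact integer accumulation once a sum exceeds 2^53).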
    
    ### Are these changes tested?
    
    Yes, by the new `GroupBy.MeanOverflow` test in `hash_aggregate_test.cc`.
    
    ### Are there any user-facing changes?
    No.
    
    * Closes: #38833
    
    Authored-by: Jin Shang <[email protected]>
    Signed-off-by: Antoine Pitrou <[email protected]>
---
 cpp/src/arrow/acero/hash_aggregate_test.cc      | 36 +++++++++++++++++++++++++
 cpp/src/arrow/compute/kernels/hash_aggregate.cc | 24 ++++++++++++-----
 2 files changed, 54 insertions(+), 6 deletions(-)

diff --git a/cpp/src/arrow/acero/hash_aggregate_test.cc b/cpp/src/arrow/acero/hash_aggregate_test.cc
index a4874f3581..2626fd5037 100644
--- a/cpp/src/arrow/acero/hash_aggregate_test.cc
+++ b/cpp/src/arrow/acero/hash_aggregate_test.cc
@@ -1694,6 +1694,42 @@ TEST_P(GroupBy, SumMeanProductScalar) {
   }
 }
 
+TEST_P(GroupBy, MeanOverflow) {
+  BatchesWithSchema input;
+  // would overflow if intermediate sum is integer
+  input.batches = {
+      ExecBatchFromJSON({int64(), int64()}, {ArgShape::SCALAR, ArgShape::ARRAY},
+
+                        "[[9223372036854775805, 1], [9223372036854775805, 1], "
+                        "[9223372036854775805, 2], [9223372036854775805, 3]]"),
+      ExecBatchFromJSON({int64(), int64()}, {ArgShape::SCALAR, ArgShape::ARRAY},
+                        "[[null, 1], [null, 1], [null, 2], [null, 3]]"),
+      ExecBatchFromJSON({int64(), int64()},
+                        "[[9223372036854775805, 1], [9223372036854775805, 2], "
+                        "[9223372036854775805, 3]]"),
+  };
+  input.schema = schema({field("argument", int64()), field("key", int64())});
+  for (bool use_threads : {true, false}) {
+    SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+    ASSERT_OK_AND_ASSIGN(Datum actual,
+                         RunGroupBy(input, {"key"},
+                                    {
+                                        {"hash_mean", nullptr, "argument", "hash_mean"},
+                                    },
+                                    use_threads));
+    Datum expected = ArrayFromJSON(struct_({
+                                       field("key", int64()),
+                                       field("hash_mean", float64()),
+                                   }),
+                                   R"([
+      [1, 9223372036854775805],
+      [2, 9223372036854775805],
+      [3, 9223372036854775805]
+    ])");
+    AssertDatumsApproxEqual(expected, actual, /*verbose=*/true);
+  }
+}
+
 TEST_P(GroupBy, VarianceAndStddev) {
   auto batch = RecordBatchFromJSON(
       schema({field("argument", int32()), field("key", int64())}), R"([
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
index c37e45513d..5052d8dd66 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
@@ -38,6 +38,7 @@
 #include "arrow/compute/row/grouper.h"
 #include "arrow/record_batch.h"
 #include "arrow/stl_allocator.h"
+#include "arrow/type_traits.h"
 #include "arrow/util/bit_run_reader.h"
 #include "arrow/util/bitmap_ops.h"
 #include "arrow/util/bitmap_writer.h"
@@ -441,9 +442,10 @@ struct GroupedCountImpl : public GroupedAggregator {
 // ----------------------------------------------------------------------
 // Sum/Mean/Product implementation
 
-template <typename Type, typename Impl>
+template <typename Type, typename Impl,
+          typename AccumulateType = typename FindAccumulatorType<Type>::Type>
 struct GroupedReducingAggregator : public GroupedAggregator {
-  using AccType = typename FindAccumulatorType<Type>::Type;
+  using AccType = AccumulateType;
   using CType = typename TypeTraits<AccType>::CType;
   using InputCType = typename TypeTraits<Type>::CType;
 
@@ -483,7 +485,8 @@ struct GroupedReducingAggregator : public GroupedAggregator {
 
   Status Merge(GroupedAggregator&& raw_other,
                const ArrayData& group_id_mapping) override {
-    auto other = checked_cast<GroupedReducingAggregator<Type, Impl>*>(&raw_other);
+    auto other =
+        checked_cast<GroupedReducingAggregator<Type, Impl, AccType>*>(&raw_other);
 
     CType* reduced = reduced_.mutable_data();
     int64_t* counts = counts_.mutable_data();
@@ -733,9 +736,18 @@ using GroupedProductFactory =
 // ----------------------------------------------------------------------
 // Mean implementation
 
+template <typename T>
+struct GroupedMeanAccType {
+  using Type = typename std::conditional<is_number_type<T>::value, DoubleType,
+                                         typename FindAccumulatorType<T>::Type>::type;
+};
+
 template <typename Type>
-struct GroupedMeanImpl : public GroupedReducingAggregator<Type, GroupedMeanImpl<Type>> {
-  using Base = GroupedReducingAggregator<Type, GroupedMeanImpl<Type>>;
+struct GroupedMeanImpl
+    : public GroupedReducingAggregator<Type, GroupedMeanImpl<Type>,
+                                       typename GroupedMeanAccType<Type>::Type> {
+  using Base = GroupedReducingAggregator<Type, GroupedMeanImpl<Type>,
+                                         typename GroupedMeanAccType<Type>::Type>;
   using CType = typename Base::CType;
   using InputCType = typename Base::InputCType;
   using MeanType =
@@ -746,7 +758,7 @@ struct GroupedMeanImpl : public GroupedReducingAggregator<Type, GroupedMeanImpl<
   template <typename T = Type>
   static enable_if_number<T, CType> Reduce(const DataType&, const CType u,
                                            const InputCType v) {
-    return static_cast<CType>(to_unsigned(u) + to_unsigned(static_cast<CType>(v)));
+    return static_cast<CType>(u) + static_cast<CType>(v);
   }
 
   static CType Reduce(const DataType&, const CType u, const CType v) {

Reply via email to