zanmato1984 commented on PR #44184:
URL: https://github.com/apache/arrow/pull/44184#issuecomment-2956606467

   Thank you @khwilson for the update. I still think it is worth further refining 
the compile-time constant into a runtime check. I created a small patch below; 
does that look reasonable to you? Thanks.
   
   ```diff
   diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.inc.cc 
b/cpp/src/arrow/compute/kernels/aggregate_basic.inc.cc
   index 31c5ca9a42..f5feb17bf2 100644
   --- a/cpp/src/arrow/compute/kernels/aggregate_basic.inc.cc
   +++ b/cpp/src/arrow/compute/kernels/aggregate_basic.inc.cc
   @@ -157,7 +157,7 @@ struct NullSumImpl : public NullImpl<ArrowType> {
      }
    };
    
   -template <template <typename> class KernelClass, bool PromoteDecimal = true>
   +template <template <typename> class KernelClass>
    struct SumLikeInit {
      std::unique_ptr<KernelState> state;
      KernelContext* ctx;
   @@ -191,7 +191,7 @@ struct SumLikeInit {
      /// However, this may not be the desired behaviour (see, e.g., 
MeanKernelInit)
      template <typename Type>
      enable_if_decimal<Type, Status> Visit(const Type&) {
   -    if constexpr (PromoteDecimal) {
   +    if (PromoteDecimal()) {
          ARROW_ASSIGN_OR_RAISE(auto ty, WidenDecimalToMaxPrecision(type));
          state.reset(new KernelClass<Type>(ty, options));
          return Status::OK();
   @@ -210,6 +210,11 @@ struct SumLikeInit {
        ARROW_RETURN_NOT_OK(VisitTypeInline(*type, this));
        return std::move(state);
      }
   +
   +  /// Whether the resulting decimal type should be widened to the maximum 
precision.
   +  /// XXX Ideally this should be able to be configured in the function 
options, e.g.,
   +  /// enum PrecisionPolicy {PROMOTE_TO_MAX, DEMOTE_TO_DOUBLE, NO_PROMOTION};
   +  virtual bool PromoteDecimal() const { return true; }
    };
    
    // ----------------------------------------------------------------------
   @@ -278,15 +283,17 @@ struct MeanImpl<ArrowType, SimdLevel,
    };
    
    template <template <typename> class KernelClass>
   -struct MeanKernelInit : public SumLikeInit<KernelClass, 
/*PromoteDecimal=*/false> {
   +struct MeanKernelInit : public SumLikeInit<KernelClass> {
      MeanKernelInit(KernelContext* ctx, std::shared_ptr<DataType> type,
                     const ScalarAggregateOptions& options)
   -      : SumLikeInit<KernelClass, /*PromoteDecimal=*/false>(ctx, type, 
options) {}
   +      : SumLikeInit<KernelClass>(ctx, type, options) {}
    
      Status Visit(const NullType&) override {
        this->state.reset(new NullSumImpl<DoubleType>(this->options));
        return Status::OK();
      }
   +
   +  bool PromoteDecimal() const override { return false; }
    };
    
    // ----------------------------------------------------------------------
   diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_numeric.cc 
b/cpp/src/arrow/compute/kernels/hash_aggregate_numeric.cc
   index bdb1db7c92..81f017679e 100644
   --- a/cpp/src/arrow/compute/kernels/hash_aggregate_numeric.cc
   +++ b/cpp/src/arrow/compute/kernels/hash_aggregate_numeric.cc
   @@ -42,8 +42,7 @@ namespace {
    // Sum/Mean/Product implementation
    
    template <typename Type, typename Impl,
   -          typename AccumulateType = typename 
FindAccumulatorType<Type>::Type,
   -          bool PromoteDecimal = false>
   +          typename AccumulateType = typename 
FindAccumulatorType<Type>::Type>
    struct GroupedReducingAggregator : public GroupedAggregator {
      using AccType = AccumulateType;
      using CType = typename TypeTraits<AccType>::CType;
   @@ -86,8 +85,7 @@ struct GroupedReducingAggregator : public 
GroupedAggregator {
      Status Merge(GroupedAggregator&& raw_other,
                   const ArrayData& group_id_mapping) override {
        auto other =
   -        checked_cast<GroupedReducingAggregator<Type, Impl, AccType, 
PromoteDecimal>*>(
   -            &raw_other);
   +        checked_cast<GroupedReducingAggregator<Type, Impl, 
AccType>*>(&raw_other);
    
        CType* reduced = reduced_.mutable_data();
        int64_t* counts = counts_.mutable_data();
   @@ -156,21 +154,26 @@ struct GroupedReducingAggregator : public 
GroupedAggregator {
      std::shared_ptr<DataType> out_type() const override { return out_type_; }
    
      template <typename T = Type>
   -  static enable_if_t<!is_decimal_type<T>::value, 
Result<std::shared_ptr<DataType>>>
   -  GetOutType(const std::shared_ptr<DataType>& in_type) {
   +  enable_if_t<!is_decimal_type<T>::value, 
Result<std::shared_ptr<DataType>>> GetOutType(
   +      const std::shared_ptr<DataType>& in_type) {
        return TypeTraits<AccType>::type_singleton();
      }
    
      template <typename T = Type>
   -  static enable_if_decimal<T, Result<std::shared_ptr<DataType>>> GetOutType(
   +  enable_if_decimal<T, Result<std::shared_ptr<DataType>>> GetOutType(
          const std::shared_ptr<DataType>& in_type) {
   -    if constexpr (PromoteDecimal) {
   +    if (PromoteDecimal()) {
          return WidenDecimalToMaxPrecision(in_type);
        } else {
          return in_type;
        }
      }
    
   +  /// Whether the resulting decimal type should be widened to the maximum 
precision.
   +  /// XXX Ideally this should be able to be configured in the function 
options, e.g.,
   +  /// enum PrecisionPolicy {PROMOTE_TO_MAX, DEMOTE_TO_DOUBLE, NO_PROMOTION};
   +  virtual bool PromoteDecimal() const { return false; }
   +
      int64_t num_groups_ = 0;
      ScalarAggregateOptions options_;
      TypedBufferBuilder<CType> reduced_;
   @@ -267,11 +270,9 @@ struct GroupedReducingFactory {
    template <typename Type>
    struct GroupedSumImpl
        : public GroupedReducingAggregator<Type, GroupedSumImpl<Type>,
   -                                       typename 
FindAccumulatorType<Type>::Type,
   -                                       /*PromoteDecimal=*/true> {
   +                                       typename 
FindAccumulatorType<Type>::Type> {
      using Base = GroupedReducingAggregator<Type, GroupedSumImpl<Type>,
   -                                         typename 
FindAccumulatorType<Type>::Type,
   -                                         /*PromoteDecimal=*/true>;
   +                                         typename 
FindAccumulatorType<Type>::Type>;
      using CType = typename Base::CType;
      using InputCType = typename Base::InputCType;
    
   @@ -288,6 +289,8 @@ struct GroupedSumImpl
        return static_cast<CType>(to_unsigned(u) + to_unsigned(v));
      }
    
   +  bool PromoteDecimal() const override { return true; }
   +
      using Base::Finish;
    };
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to