zanmato1984 commented on PR #44184:
URL: https://github.com/apache/arrow/pull/44184#issuecomment-2956606467
Thank you @khwilson for the update. I still think it is worth further refining
the compile-time constant into a runtime check. I created a little patch below;
does that look reasonable to you? Thanks.
```diff
diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.inc.cc
b/cpp/src/arrow/compute/kernels/aggregate_basic.inc.cc
index 31c5ca9a42..f5feb17bf2 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_basic.inc.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate_basic.inc.cc
@@ -157,7 +157,7 @@ struct NullSumImpl : public NullImpl<ArrowType> {
}
};
-template <template <typename> class KernelClass, bool PromoteDecimal = true>
+template <template <typename> class KernelClass>
struct SumLikeInit {
std::unique_ptr<KernelState> state;
KernelContext* ctx;
@@ -191,7 +191,7 @@ struct SumLikeInit {
/// However, this may not be the desired behaviour (see, e.g.,
MeanKernelInit)
template <typename Type>
enable_if_decimal<Type, Status> Visit(const Type&) {
- if constexpr (PromoteDecimal) {
+ if (PromoteDecimal()) {
ARROW_ASSIGN_OR_RAISE(auto ty, WidenDecimalToMaxPrecision(type));
state.reset(new KernelClass<Type>(ty, options));
return Status::OK();
@@ -210,6 +210,11 @@ struct SumLikeInit {
ARROW_RETURN_NOT_OK(VisitTypeInline(*type, this));
return std::move(state);
}
+
+ /// Whether the resulting decimal type should be widened to the maximum
precision.
+ /// XXX Ideally this should be able to be configured in the function
options, e.g.,
+ /// enum PrecisionPolicy {PROMOTE_TO_MAX, DEMOTE_TO_DOUBLE, NO_PROMOTION};
+ virtual bool PromoteDecimal() const { return true; }
};
// ----------------------------------------------------------------------
@@ -278,15 +283,17 @@ struct MeanImpl<ArrowType, SimdLevel,
};
template <template <typename> class KernelClass>
-struct MeanKernelInit : public SumLikeInit<KernelClass,
/*PromoteDecimal=*/false> {
+struct MeanKernelInit : public SumLikeInit<KernelClass> {
MeanKernelInit(KernelContext* ctx, std::shared_ptr<DataType> type,
const ScalarAggregateOptions& options)
- : SumLikeInit<KernelClass, /*PromoteDecimal=*/false>(ctx, type,
options) {}
+ : SumLikeInit<KernelClass>(ctx, type, options) {}
Status Visit(const NullType&) override {
this->state.reset(new NullSumImpl<DoubleType>(this->options));
return Status::OK();
}
+
+ bool PromoteDecimal() const override { return false; }
};
// ----------------------------------------------------------------------
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_numeric.cc
b/cpp/src/arrow/compute/kernels/hash_aggregate_numeric.cc
index bdb1db7c92..81f017679e 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate_numeric.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate_numeric.cc
@@ -42,8 +42,7 @@ namespace {
// Sum/Mean/Product implementation
template <typename Type, typename Impl,
- typename AccumulateType = typename
FindAccumulatorType<Type>::Type,
- bool PromoteDecimal = false>
+ typename AccumulateType = typename
FindAccumulatorType<Type>::Type>
struct GroupedReducingAggregator : public GroupedAggregator {
using AccType = AccumulateType;
using CType = typename TypeTraits<AccType>::CType;
@@ -86,8 +85,7 @@ struct GroupedReducingAggregator : public
GroupedAggregator {
Status Merge(GroupedAggregator&& raw_other,
const ArrayData& group_id_mapping) override {
auto other =
- checked_cast<GroupedReducingAggregator<Type, Impl, AccType,
PromoteDecimal>*>(
- &raw_other);
+ checked_cast<GroupedReducingAggregator<Type, Impl,
AccType>*>(&raw_other);
CType* reduced = reduced_.mutable_data();
int64_t* counts = counts_.mutable_data();
@@ -156,21 +154,26 @@ struct GroupedReducingAggregator : public
GroupedAggregator {
std::shared_ptr<DataType> out_type() const override { return out_type_; }
template <typename T = Type>
- static enable_if_t<!is_decimal_type<T>::value,
Result<std::shared_ptr<DataType>>>
- GetOutType(const std::shared_ptr<DataType>& in_type) {
+ enable_if_t<!is_decimal_type<T>::value,
Result<std::shared_ptr<DataType>>> GetOutType(
+ const std::shared_ptr<DataType>& in_type) {
return TypeTraits<AccType>::type_singleton();
}
template <typename T = Type>
- static enable_if_decimal<T, Result<std::shared_ptr<DataType>>> GetOutType(
+ enable_if_decimal<T, Result<std::shared_ptr<DataType>>> GetOutType(
const std::shared_ptr<DataType>& in_type) {
- if constexpr (PromoteDecimal) {
+ if (PromoteDecimal()) {
return WidenDecimalToMaxPrecision(in_type);
} else {
return in_type;
}
}
+ /// Whether the resulting decimal type should be widened to the maximum
precision.
+ /// XXX Ideally this should be able to be configured in the function
options, e.g.,
+ /// enum PrecisionPolicy {PROMOTE_TO_MAX, DEMOTE_TO_DOUBLE, NO_PROMOTION};
+ virtual bool PromoteDecimal() const { return false; }
+
int64_t num_groups_ = 0;
ScalarAggregateOptions options_;
TypedBufferBuilder<CType> reduced_;
@@ -267,11 +270,9 @@ struct GroupedReducingFactory {
template <typename Type>
struct GroupedSumImpl
: public GroupedReducingAggregator<Type, GroupedSumImpl<Type>,
- typename
FindAccumulatorType<Type>::Type,
- /*PromoteDecimal=*/true> {
+ typename
FindAccumulatorType<Type>::Type> {
using Base = GroupedReducingAggregator<Type, GroupedSumImpl<Type>,
- typename
FindAccumulatorType<Type>::Type,
- /*PromoteDecimal=*/true>;
+ typename
FindAccumulatorType<Type>::Type>;
using CType = typename Base::CType;
using InputCType = typename Base::InputCType;
@@ -288,6 +289,8 @@ struct GroupedSumImpl
return static_cast<CType>(to_unsigned(u) + to_unsigned(v));
}
+ bool PromoteDecimal() const override { return true; }
+
using Base::Finish;
};
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]