This is an automated email from the ASF dual-hosted git repository.

panxiaolei pushed a commit to branch tpc_preview3
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/tpc_preview3 by this push:
     new 30ac8e0c3a7 reduce cast of input arg from decimal avg
30ac8e0c3a7 is described below

commit 30ac8e0c3a72f6d431087a10ef5a9fa9291742d4
Author: BiteTheDDDDt <[email protected]>
AuthorDate: Wed Dec 3 14:52:26 2025 +0800

    reduce cast of input arg from decimal avg
    
    update
    
    fix
    
    fix
    
    fix
    
    fix
    
    fix
---
 .../aggregate_functions/aggregate_function_avg.h   | 85 ++++++++++++++--------
 .../trees/expressions/functions/agg/Avg.java       | 16 +---
 2 files changed, 55 insertions(+), 46 deletions(-)

diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg.h b/be/src/vec/aggregate_functions/aggregate_function_avg.h
index cf65b41d90f..ed62cd9c691 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_avg.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_avg.h
@@ -55,24 +55,13 @@ template <PrimitiveType T>
 struct AggregateFunctionAvgData {
     using ResultType = typename PrimitiveTypeTraits<T>::ColumnItemType;
     static constexpr PrimitiveType ResultPType = T;
-    typename PrimitiveTypeTraits<T>::ColumnItemType sum {};
+    ResultType sum {};
     UInt64 count = 0;
 
-    AggregateFunctionAvgData& operator=(const AggregateFunctionAvgData<T>& src) {
-        sum = src.sum;
-        count = src.count;
-        return *this;
-    }
+    AggregateFunctionAvgData& operator=(const AggregateFunctionAvgData<T>& src) = default;
 
     template <typename ResultT>
-    ResultT result() const {
-        if constexpr (std::is_floating_point_v<ResultT>) {
-            if constexpr (std::numeric_limits<ResultT>::is_iec559) {
-                return static_cast<ResultT>(sum) /
-                       static_cast<ResultT>(count); /// allow division by zero
-            }
-        }
-
+    ResultT result(ResultType multiplier) const {
         if (!count) {
             // null is handled in AggregationNode::_get_without_key_result
             return static_cast<ResultT>(sum);
@@ -80,18 +69,34 @@ struct AggregateFunctionAvgData {
         // to keep the same result with row vesion; see AggregateFunctions::decimalv2_avg_get_value
         if constexpr (T == TYPE_DECIMALV2 && IsDecimalV2<ResultT>) {
             DecimalV2Value decimal_val_count(count, 0);
-            DecimalV2Value decimal_val_sum(sum);
+            DecimalV2Value decimal_val_sum(sum * multiplier);
             DecimalV2Value cal_ret = decimal_val_sum / decimal_val_count;
             Decimal128V2 ret(cal_ret.value());
             return ret;
         } else {
             if constexpr (T == TYPE_DECIMAL256) {
-                return static_cast<ResultT>(sum /
+                return static_cast<ResultT>(sum * multiplier /
                                             typename PrimitiveTypeTraits<T>::ColumnItemType(count));
             } else {
-                return static_cast<ResultT>(sum) / static_cast<ResultT>(count);
+                return static_cast<ResultT>(sum * multiplier) / static_cast<ResultT>(count);
+            }
+        }
+    }
+
+    template <typename ResultT>
+    ResultT result() const {
+        if constexpr (std::is_floating_point_v<ResultT>) {
+            if constexpr (std::numeric_limits<ResultT>::is_iec559) {
+                return static_cast<ResultT>(sum) /
+                       static_cast<ResultT>(count); /// allow division by zero
             }
         }
+
+        if (!count) {
+            // null is handled in AggregationNode::_get_without_key_result
+            return static_cast<ResultT>(sum);
+        }
+        return static_cast<ResultT>(sum) / static_cast<ResultT>(count);
     }
 
     void write(BufferWritable& buf) const {
@@ -112,31 +117,41 @@ class AggregateFunctionAvg final
           UnaryExpression,
           NullableAggregateFunction {
 public:
-    using ResultType = std::conditional_t<
-            T == TYPE_DECIMALV2, Decimal128V2,
-            std::conditional_t<is_decimal(T), typename Data::ResultType, Float64>>;
-    using ResultDataType = std::conditional_t<
-            T == TYPE_DECIMALV2, DataTypeDecimalV2,
-            std::conditional_t<is_decimal(T), DataTypeDecimal<Data::ResultPType>, DataTypeFloat64>>;
+    using ResultDataType =
+            std::conditional_t<is_decimal(T),
+                               typename PrimitiveTypeTraits<Data::ResultPType>::DataType,
+                               DataTypeFloat64>;
     using ColVecType = typename PrimitiveTypeTraits<T>::ColumnType;
-    using ColVecResult = std::conditional_t<
-            T == TYPE_DECIMALV2, ColumnDecimal128V2,
-            std::conditional_t<is_decimal(T), ColumnDecimal<Data::ResultPType>, ColumnFloat64>>;
+    using ColVecResult =
+            std::conditional_t<is_decimal(T),
+                               typename PrimitiveTypeTraits<Data::ResultPType>::ColumnType,
+                               ColumnFloat64>;
     // The result calculated by PercentileApprox is an approximate value,
     // so the underlying storage uses float. The following calls will involve
     // an implicit cast to float.
 
     using DataType = typename Data::ResultType;
+    using ResultType = std::conditional_t<is_decimal(T), DataType, Float64>;
     /// ctor for native types
+    // consistent with fe/fe-common/src/main/java/org/apache/doris/catalog/ScalarType.java
+    static constexpr uint32_t DEFAULT_MIN_AVG_DECIMAL128_SCALE = 4;
     AggregateFunctionAvg(const DataTypes& argument_types_)
             : IAggregateFunctionDataHelper<Data, AggregateFunctionAvg<T, Data>>(argument_types_),
-              scale(get_decimal_scale(*argument_types_[0])) {}
+              input_scale(get_decimal_scale(*argument_types_[0])),
+              output_scale(std::max(DEFAULT_MIN_AVG_DECIMAL128_SCALE, input_scale)) {
+        if constexpr (is_decimal(T)) {
+            multiplier =
+                    ResultType(ResultDataType::get_scale_multiplier(output_scale - input_scale));
+        }
+    }
 
     String get_name() const override { return "avg"; }
 
     DataTypePtr get_return_type() const override {
         if constexpr (is_decimal(T)) {
-            return std::make_shared<ResultDataType>(ResultDataType::max_precision(), scale);
+            return std::make_shared<ResultDataType>(
+                    ResultDataType::max_precision(),
+                    std::max(DEFAULT_MIN_AVG_DECIMAL128_SCALE, output_scale));
         } else {
             return std::make_shared<ResultDataType>();
         }
@@ -152,14 +167,14 @@ public:
                 assert_cast<const ColVecType&, TypeCheckOnRelease::DISABLE>(*columns[0]);
         if constexpr (is_add) {
             if constexpr (is_decimal(T)) {
-                this->data(place).sum += (DataType)column.get_data()[row_num].value;
+                this->data(place).sum += column.get_data()[row_num].value;
             } else {
                 this->data(place).sum += (DataType)column.get_data()[row_num];
             }
             ++this->data(place).count;
         } else {
             if constexpr (is_decimal(T)) {
-                this->data(place).sum -= (DataType)column.get_data()[row_num].value;
+                this->data(place).sum -= column.get_data()[row_num].value;
             } else {
                 this->data(place).sum -= (DataType)column.get_data()[row_num];
             }
@@ -198,7 +213,11 @@ public:
 
     void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
         auto& column = assert_cast<ColVecResult&>(to);
-        column.get_data().push_back(this->data(place).template result<ResultType>());
+        if constexpr (is_decimal(T)) {
+            column.get_data().push_back(this->data(place).template result<ResultType>(multiplier));
+        } else {
+            column.get_data().push_back(this->data(place).template result<ResultType>());
+        }
     }
 
     void deserialize_from_column(AggregateDataPtr places, const IColumn& column, Arena&,
@@ -346,7 +365,9 @@ public:
     }
 
 private:
-    UInt32 scale;
+    uint32_t input_scale;
+    uint32_t output_scale;
+    ResultType multiplier;
 };
 
 } // namespace doris::vectorized
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Avg.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Avg.java
index d0a512027cc..0c7e0414bb3 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Avg.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Avg.java
@@ -103,28 +103,16 @@ public class Avg extends NullableAggregateFunction
             }
             DecimalV3Type decimalV3Type = DecimalV3Type.forType(argumentType);
             // DecimalV3 scale lower than DEFAULT_MIN_AVG_DECIMAL128_SCALE should do cast
-            int precision = decimalV3Type.getPrecision();
             int scale = decimalV3Type.getScale();
             if (decimalV3Type.getScale() < ScalarType.DEFAULT_MIN_AVG_DECIMAL128_SCALE) {
                 scale = ScalarType.DEFAULT_MIN_AVG_DECIMAL128_SCALE;
-                precision = precision - decimalV3Type.getScale() + scale;
-                if (enableDecimal256) {
-                    if (precision > DecimalV3Type.MAX_DECIMAL256_PRECISION) {
-                        precision = DecimalV3Type.MAX_DECIMAL256_PRECISION;
-                    }
-                } else {
-                    if (precision > DecimalV3Type.MAX_DECIMAL128_PRECISION) {
-                        precision = DecimalV3Type.MAX_DECIMAL128_PRECISION;
-                    }
-                }
             }
-            decimalV3Type = DecimalV3Type.createDecimalV3Type(precision, scale);
             return signature.withArgumentType(0, decimalV3Type)
                     .withReturnType(DecimalV3Type.createDecimalV3Type(
                             enableDecimal256 ? DecimalV3Type.MAX_DECIMAL256_PRECISION
                                     : DecimalV3Type.MAX_DECIMAL128_PRECISION,
-                            decimalV3Type.getScale()
-            ));
+                            scale)
+                    );
         } else {
             return signature;
         }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to