This is an automated email from the ASF dual-hosted git repository.
panxiaolei pushed a commit to branch tpc_preview3
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/tpc_preview3 by this push:
new 30ac8e0c3a7 reduce cast of input arg from decimal avg
30ac8e0c3a7 is described below
commit 30ac8e0c3a72f6d431087a10ef5a9fa9291742d4
Author: BiteTheDDDDt <[email protected]>
AuthorDate: Wed Dec 3 14:52:26 2025 +0800
reduce cast of input arg from decimal avg
update
fix
fix
fix
fix
fix
---
.../aggregate_functions/aggregate_function_avg.h | 85 ++++++++++++++--------
.../trees/expressions/functions/agg/Avg.java | 16 +---
2 files changed, 55 insertions(+), 46 deletions(-)
diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg.h b/be/src/vec/aggregate_functions/aggregate_function_avg.h
index cf65b41d90f..ed62cd9c691 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_avg.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_avg.h
@@ -55,24 +55,13 @@ template <PrimitiveType T>
struct AggregateFunctionAvgData {
using ResultType = typename PrimitiveTypeTraits<T>::ColumnItemType;
static constexpr PrimitiveType ResultPType = T;
- typename PrimitiveTypeTraits<T>::ColumnItemType sum {};
+ ResultType sum {};
UInt64 count = 0;
- AggregateFunctionAvgData& operator=(const AggregateFunctionAvgData<T>&
src) {
- sum = src.sum;
- count = src.count;
- return *this;
- }
+ AggregateFunctionAvgData& operator=(const AggregateFunctionAvgData<T>&
src) = default;
template <typename ResultT>
- ResultT result() const {
- if constexpr (std::is_floating_point_v<ResultT>) {
- if constexpr (std::numeric_limits<ResultT>::is_iec559) {
- return static_cast<ResultT>(sum) /
- static_cast<ResultT>(count); /// allow division by zero
- }
- }
-
+ ResultT result(ResultType multiplier) const {
if (!count) {
// null is handled in AggregationNode::_get_without_key_result
return static_cast<ResultT>(sum);
@@ -80,18 +69,34 @@ struct AggregateFunctionAvgData {
// to keep the same result with row vesion; see
AggregateFunctions::decimalv2_avg_get_value
if constexpr (T == TYPE_DECIMALV2 && IsDecimalV2<ResultT>) {
DecimalV2Value decimal_val_count(count, 0);
- DecimalV2Value decimal_val_sum(sum);
+ DecimalV2Value decimal_val_sum(sum * multiplier);
DecimalV2Value cal_ret = decimal_val_sum / decimal_val_count;
Decimal128V2 ret(cal_ret.value());
return ret;
} else {
if constexpr (T == TYPE_DECIMAL256) {
- return static_cast<ResultT>(sum /
+ return static_cast<ResultT>(sum * multiplier /
typename
PrimitiveTypeTraits<T>::ColumnItemType(count));
} else {
- return static_cast<ResultT>(sum) / static_cast<ResultT>(count);
+ return static_cast<ResultT>(sum * multiplier) /
static_cast<ResultT>(count);
+ }
+ }
+ }
+
+ template <typename ResultT>
+ ResultT result() const {
+ if constexpr (std::is_floating_point_v<ResultT>) {
+ if constexpr (std::numeric_limits<ResultT>::is_iec559) {
+ return static_cast<ResultT>(sum) /
+ static_cast<ResultT>(count); /// allow division by zero
}
}
+
+ if (!count) {
+ // null is handled in AggregationNode::_get_without_key_result
+ return static_cast<ResultT>(sum);
+ }
+ return static_cast<ResultT>(sum) / static_cast<ResultT>(count);
}
void write(BufferWritable& buf) const {
@@ -112,31 +117,41 @@ class AggregateFunctionAvg final
UnaryExpression,
NullableAggregateFunction {
public:
- using ResultType = std::conditional_t<
- T == TYPE_DECIMALV2, Decimal128V2,
- std::conditional_t<is_decimal(T), typename Data::ResultType,
Float64>>;
- using ResultDataType = std::conditional_t<
- T == TYPE_DECIMALV2, DataTypeDecimalV2,
- std::conditional_t<is_decimal(T),
DataTypeDecimal<Data::ResultPType>, DataTypeFloat64>>;
+ using ResultDataType =
+ std::conditional_t<is_decimal(T),
+ typename
PrimitiveTypeTraits<Data::ResultPType>::DataType,
+ DataTypeFloat64>;
using ColVecType = typename PrimitiveTypeTraits<T>::ColumnType;
- using ColVecResult = std::conditional_t<
- T == TYPE_DECIMALV2, ColumnDecimal128V2,
- std::conditional_t<is_decimal(T),
ColumnDecimal<Data::ResultPType>, ColumnFloat64>>;
+ using ColVecResult =
+ std::conditional_t<is_decimal(T),
+ typename
PrimitiveTypeTraits<Data::ResultPType>::ColumnType,
+ ColumnFloat64>;
// The result calculated by PercentileApprox is an approximate value,
// so the underlying storage uses float. The following calls will involve
// an implicit cast to float.
using DataType = typename Data::ResultType;
+ using ResultType = std::conditional_t<is_decimal(T), DataType, Float64>;
/// ctor for native types
+ // consistent with
fe/fe-common/src/main/java/org/apache/doris/catalog/ScalarType.java
+ static constexpr uint32_t DEFAULT_MIN_AVG_DECIMAL128_SCALE = 4;
AggregateFunctionAvg(const DataTypes& argument_types_)
: IAggregateFunctionDataHelper<Data, AggregateFunctionAvg<T,
Data>>(argument_types_),
- scale(get_decimal_scale(*argument_types_[0])) {}
+ input_scale(get_decimal_scale(*argument_types_[0])),
+ output_scale(std::max(DEFAULT_MIN_AVG_DECIMAL128_SCALE,
input_scale)) {
+ if constexpr (is_decimal(T)) {
+ multiplier =
+
ResultType(ResultDataType::get_scale_multiplier(output_scale - input_scale));
+ }
+ }
String get_name() const override { return "avg"; }
DataTypePtr get_return_type() const override {
if constexpr (is_decimal(T)) {
- return
std::make_shared<ResultDataType>(ResultDataType::max_precision(), scale);
+ return std::make_shared<ResultDataType>(
+ ResultDataType::max_precision(),
+ std::max(DEFAULT_MIN_AVG_DECIMAL128_SCALE, output_scale));
} else {
return std::make_shared<ResultDataType>();
}
@@ -152,14 +167,14 @@ public:
assert_cast<const ColVecType&,
TypeCheckOnRelease::DISABLE>(*columns[0]);
if constexpr (is_add) {
if constexpr (is_decimal(T)) {
- this->data(place).sum +=
(DataType)column.get_data()[row_num].value;
+ this->data(place).sum += column.get_data()[row_num].value;
} else {
this->data(place).sum += (DataType)column.get_data()[row_num];
}
++this->data(place).count;
} else {
if constexpr (is_decimal(T)) {
- this->data(place).sum -=
(DataType)column.get_data()[row_num].value;
+ this->data(place).sum -= column.get_data()[row_num].value;
} else {
this->data(place).sum -= (DataType)column.get_data()[row_num];
}
@@ -198,7 +213,11 @@ public:
void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn&
to) const override {
auto& column = assert_cast<ColVecResult&>(to);
- column.get_data().push_back(this->data(place).template
result<ResultType>());
+ if constexpr (is_decimal(T)) {
+ column.get_data().push_back(this->data(place).template
result<ResultType>(multiplier));
+ } else {
+ column.get_data().push_back(this->data(place).template
result<ResultType>());
+ }
}
void deserialize_from_column(AggregateDataPtr places, const IColumn&
column, Arena&,
@@ -346,7 +365,9 @@ public:
}
private:
- UInt32 scale;
+ uint32_t input_scale;
+ uint32_t output_scale;
+ ResultType multiplier;
};
} // namespace doris::vectorized
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Avg.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Avg.java
index d0a512027cc..0c7e0414bb3 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Avg.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Avg.java
@@ -103,28 +103,16 @@ public class Avg extends NullableAggregateFunction
}
DecimalV3Type decimalV3Type = DecimalV3Type.forType(argumentType);
// DecimalV3 scale lower than DEFAULT_MIN_AVG_DECIMAL128_SCALE
should do cast
- int precision = decimalV3Type.getPrecision();
int scale = decimalV3Type.getScale();
if (decimalV3Type.getScale() <
ScalarType.DEFAULT_MIN_AVG_DECIMAL128_SCALE) {
scale = ScalarType.DEFAULT_MIN_AVG_DECIMAL128_SCALE;
- precision = precision - decimalV3Type.getScale() + scale;
- if (enableDecimal256) {
- if (precision > DecimalV3Type.MAX_DECIMAL256_PRECISION) {
- precision = DecimalV3Type.MAX_DECIMAL256_PRECISION;
- }
- } else {
- if (precision > DecimalV3Type.MAX_DECIMAL128_PRECISION) {
- precision = DecimalV3Type.MAX_DECIMAL128_PRECISION;
- }
- }
}
- decimalV3Type = DecimalV3Type.createDecimalV3Type(precision,
scale);
return signature.withArgumentType(0, decimalV3Type)
.withReturnType(DecimalV3Type.createDecimalV3Type(
enableDecimal256 ?
DecimalV3Type.MAX_DECIMAL256_PRECISION
: DecimalV3Type.MAX_DECIMAL128_PRECISION,
- decimalV3Type.getScale()
- ));
+ scale)
+ );
} else {
return signature;
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org