This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new f58a071605 [Bug][Function] pass intermediate argument list to be
(#10650)
f58a071605 is described below
commit f58a071605a1aaa7a68a99cdd2f098a5868787e4
Author: Pxl <[email protected]>
AuthorDate: Fri Jul 8 20:50:05 2022 +0800
[Bug][Function] pass intermediate argument list to be (#10650)
---
.../aggregate_function_orthogonal_bitmap.cpp | 2 --
.../aggregate_function_topn.cpp | 5 +----
.../aggregate_functions/aggregate_function_topn.h | 8 --------
be/src/vec/data_types/data_type_factory.hpp | 4 ++++
be/src/vec/exprs/vectorized_agg_fn.cpp | 22 ++++++++++------------
be/src/vec/exprs/vectorized_agg_fn.h | 2 +-
.../org/apache/doris/analysis/AggregateInfo.java | 15 ++++-----------
.../apache/doris/analysis/FunctionCallExpr.java | 20 +++++++++++++++-----
.../org/apache/doris/analysis/FunctionParams.java | 15 +++++++++++++++
gensrc/thrift/Exprs.thrift | 1 +
gensrc/thrift/Types.thrift | 1 +
11 files changed, 52 insertions(+), 43 deletions(-)
diff --git
a/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp
b/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp
index 470a6c8388..9794a72090 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp
@@ -34,8 +34,6 @@ AggregateFunctionPtr
create_aggregate_function_orthogonal(const std::string& nam
LOG(WARNING) << "Incorrect number of arguments for aggregate function
" << name;
return nullptr;
} else if (argument_types.size() == 1) {
- // only used at AGGREGATE (merge finalize) for variadic function
- // and for orthogonal_bitmap_union_count function
return
std::make_shared<AggFunctionOrthBitmapFunc<Impl<StringValue>>>(argument_types);
} else {
const IDataType& argument_type = *argument_types[1];
diff --git a/be/src/vec/aggregate_functions/aggregate_function_topn.cpp
b/be/src/vec/aggregate_functions/aggregate_function_topn.cpp
index 04df93ce67..19f52fbff8 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_topn.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_topn.cpp
@@ -23,10 +23,7 @@ AggregateFunctionPtr create_aggregate_function_topn(const
std::string& name,
const DataTypes&
argument_types,
const Array& parameters,
const bool
result_is_nullable) {
- if (argument_types.size() == 1) {
- return AggregateFunctionPtr(
- new
AggregateFunctionTopN<AggregateFunctionTopNImplEmpty>(argument_types));
- } else if (argument_types.size() == 2) {
+ if (argument_types.size() == 2) {
return AggregateFunctionPtr(
new
AggregateFunctionTopN<AggregateFunctionTopNImplInt<StringDataImplTopN>>(
argument_types));
diff --git a/be/src/vec/aggregate_functions/aggregate_function_topn.h
b/be/src/vec/aggregate_functions/aggregate_function_topn.h
index 97ac5c7cba..ae9fdf322d 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_topn.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_topn.h
@@ -168,14 +168,6 @@ struct StringDataImplTopN {
}
};
-struct AggregateFunctionTopNImplEmpty {
- // only used at AGGREGATE (merge finalize)
- static void add(AggregateFunctionTopNData& __restrict place, const
IColumn** columns,
- size_t row_num) {
- LOG(FATAL) << "AggregateFunctionTopNImplEmpty do not support add()";
- }
-};
-
template <typename DataHelper>
struct AggregateFunctionTopNImplInt {
static void add(AggregateFunctionTopNData& __restrict place, const
IColumn** columns,
diff --git a/be/src/vec/data_types/data_type_factory.hpp
b/be/src/vec/data_types/data_type_factory.hpp
index 08dc6a9f31..59740debd3 100644
--- a/be/src/vec/data_types/data_type_factory.hpp
+++ b/be/src/vec/data_types/data_type_factory.hpp
@@ -102,6 +102,10 @@ public:
DataTypePtr create_data_type(const arrow::DataType* type, bool
is_nullable);
+ DataTypePtr create_data_type(const TTypeDesc& raw_type) {
+ return create_data_type(TypeDescriptor::from_thrift(raw_type),
raw_type.is_nullable);
+ }
+
private:
DataTypePtr _create_primitive_data_type(const FieldType& type) const;
diff --git a/be/src/vec/exprs/vectorized_agg_fn.cpp
b/be/src/vec/exprs/vectorized_agg_fn.cpp
index ad7066a9a4..b7e14817f1 100644
--- a/be/src/vec/exprs/vectorized_agg_fn.cpp
+++ b/be/src/vec/exprs/vectorized_agg_fn.cpp
@@ -33,7 +33,6 @@ AggFnEvaluator::AggFnEvaluator(const TExprNode& desc)
: _fn(desc.fn),
_is_merge(desc.agg_expr.is_merge_agg),
_return_type(TypeDescriptor::from_thrift(desc.fn.ret_type)),
-
_intermediate_type(TypeDescriptor::from_thrift(desc.fn.aggregate_fn.intermediate_type)),
_intermediate_slot_desc(nullptr),
_output_slot_desc(nullptr),
_exec_timer(nullptr),
@@ -44,6 +43,11 @@ AggFnEvaluator::AggFnEvaluator(const TExprNode& desc)
nullable = desc.is_nullable;
}
_data_type = DataTypeFactory::instance().create_data_type(_return_type,
nullable);
+
+ auto& param_types = desc.agg_expr.param_types;
+ for (auto raw_type : param_types) {
+
_argument_types.push_back(DataTypeFactory::instance().create_data_type(raw_type));
+ }
}
Status AggFnEvaluator::create(ObjectPool* pool, const TExpr& desc,
AggFnEvaluator** result) {
@@ -55,7 +59,7 @@ Status AggFnEvaluator::create(ObjectPool* pool, const TExpr&
desc, AggFnEvaluato
VExpr* expr = nullptr;
VExprContext* ctx = nullptr;
RETURN_IF_ERROR(
- VExpr::create_tree_from_thrift(pool, desc.nodes, NULL,
&node_idx, &expr, &ctx));
+ VExpr::create_tree_from_thrift(pool, desc.nodes, nullptr,
&node_idx, &expr, &ctx));
agg_fn_evaluator->_input_exprs_ctxs.push_back(ctx);
}
return Status::OK();
@@ -65,25 +69,19 @@ Status AggFnEvaluator::prepare(RuntimeState* state, const
RowDescriptor& desc, M
const SlotDescriptor* intermediate_slot_desc,
const SlotDescriptor* output_slot_desc,
const std::shared_ptr<MemTracker>& mem_tracker)
{
- DCHECK(pool != NULL);
- DCHECK(intermediate_slot_desc != NULL);
- DCHECK(_intermediate_slot_desc == NULL);
+ DCHECK(pool != nullptr);
+ DCHECK(intermediate_slot_desc != nullptr);
+ DCHECK(_intermediate_slot_desc == nullptr);
_output_slot_desc = output_slot_desc;
_intermediate_slot_desc = intermediate_slot_desc;
Status status = VExpr::prepare(_input_exprs_ctxs, state, desc,
mem_tracker);
RETURN_IF_ERROR(status);
- DataTypes argument_types;
- argument_types.reserve(_input_exprs_ctxs.size());
-
std::vector<std::string_view> child_expr_name;
- doris::vectorized::Array params;
// prepare for argument
for (int i = 0; i < _input_exprs_ctxs.size(); ++i) {
- auto data_type = _input_exprs_ctxs[i]->root()->data_type();
- argument_types.emplace_back(data_type);
child_expr_name.emplace_back(_input_exprs_ctxs[i]->root()->expr_name());
}
@@ -95,7 +93,7 @@ Status AggFnEvaluator::prepare(RuntimeState* state, const
RowDescriptor& desc, M
#endif
} else {
_function = AggregateFunctionSimpleFactory::instance().get(
- _fn.name.function_name, argument_types, params,
_data_type->is_nullable());
+ _fn.name.function_name, _argument_types, {},
_data_type->is_nullable());
}
if (_function == nullptr) {
return Status::InternalError("Agg Function {} is not implemented",
_fn.name.function_name);
diff --git a/be/src/vec/exprs/vectorized_agg_fn.h
b/be/src/vec/exprs/vectorized_agg_fn.h
index 0f1f145ced..9a1dbafdcc 100644
--- a/be/src/vec/exprs/vectorized_agg_fn.h
+++ b/be/src/vec/exprs/vectorized_agg_fn.h
@@ -78,8 +78,8 @@ private:
void _calc_argment_columns(Block* block);
+ DataTypes _argument_types;
const TypeDescriptor _return_type;
- const TypeDescriptor _intermediate_type;
const SlotDescriptor* _intermediate_slot_desc;
const SlotDescriptor* _output_slot_desc;
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java
b/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java
index a855b53de2..a0152a8e8a 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java
@@ -491,17 +491,10 @@ public final class AggregateInfo extends
AggregateInfoBase {
for (int i = 0; i < getAggregateExprs().size(); ++i) {
FunctionCallExpr inputExpr = getAggregateExprs().get(i);
Preconditions.checkState(inputExpr.isAggregateFunction());
- List<Expr> paramExprs = new ArrayList<>();
- // TODO(zhannngchen), change intermediate argument to a list, and
remove this
- // ad-hoc logic
- if (inputExpr.fn.functionName().equals("max_by")
- || inputExpr.fn.functionName().equals("min_by")) {
- paramExprs.addAll(inputExpr.getFnParams().exprs());
- } else {
- paramExprs.add(new SlotRef(inputDesc.getSlots().get(i +
getGroupingExprs().size())));
- }
+ Expr aggExprParam =
+ new SlotRef(inputDesc.getSlots().get(i +
getGroupingExprs().size()));
FunctionCallExpr aggExpr = FunctionCallExpr.createMergeAggCall(
- inputExpr, paramExprs);
+ inputExpr, Lists.newArrayList(aggExprParam),
inputExpr.getFnParams().exprs());
aggExpr.analyzeNoThrow(analyzer);
aggExprs.add(aggExpr);
}
@@ -623,7 +616,7 @@ public final class AggregateInfo extends AggregateInfoBase {
Expr aggExprParam =
new SlotRef(inputDesc.getSlots().get(i +
getGroupingExprs().size()));
FunctionCallExpr aggExpr = FunctionCallExpr.createMergeAggCall(
- inputExpr, Lists.newArrayList(aggExprParam));
+ inputExpr, Lists.newArrayList(aggExprParam),
inputExpr.getFnParams().exprs());
secondPhaseAggExprs.add(aggExpr);
}
Preconditions.checkState(
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
index 637bf6c158..543c74776f 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
@@ -37,7 +37,6 @@ import org.apache.doris.common.ErrorReport;
import org.apache.doris.common.util.VectorizedUtil;
import org.apache.doris.mysql.privilege.PrivPredicate;
import org.apache.doris.qe.ConnectContext;
-import org.apache.doris.thrift.TAggregateExpr;
import org.apache.doris.thrift.TExprNode;
import org.apache.doris.thrift.TExprNodeType;
@@ -69,6 +68,9 @@ public class FunctionCallExpr extends Expr {
// private BuiltinAggregateFunction.Operator aggOp;
private FunctionParams fnParams;
+ // represents the original parameters of the aggregate function
+ private FunctionParams aggFnParams;
+
// check analytic function
private boolean isAnalyticFnCall = false;
// check table function
@@ -92,6 +94,10 @@ public class FunctionCallExpr extends Expr {
private boolean isRewrote = false;
+ public void setAggFnParams(FunctionParams aggFnParams) {
+ this.aggFnParams = aggFnParams;
+ }
+
public void setIsAnalyticFnCall(boolean v) {
isAnalyticFnCall = v;
}
@@ -153,6 +159,7 @@ public class FunctionCallExpr extends Expr {
// aggOp = e.aggOp;
isAnalyticFnCall = e.isAnalyticFnCall;
fnParams = params;
+ aggFnParams = e.aggFnParams;
// Just inherit the function object from 'e'.
fn = e.fn;
this.isMergeAggFn = e.isMergeAggFn;
@@ -175,6 +182,7 @@ public class FunctionCallExpr extends Expr {
} else {
fnParams = new FunctionParams(other.fnParams.isDistinct(),
children);
}
+ aggFnParams = other.aggFnParams;
this.isMergeAggFn = other.isMergeAggFn;
fn = other.fn;
this.isTableFnCall = other.isTableFnCall;
@@ -428,9 +436,10 @@ public class FunctionCallExpr extends Expr {
// except in test cases that do it explicitly.
if (isAggregate() || isAnalyticFnCall) {
msg.node_type = TExprNodeType.AGG_EXPR;
- if (!isAnalyticFnCall) {
- msg.setAggExpr(new TAggregateExpr(isMergeAggFn));
+ if (aggFnParams == null) {
+ aggFnParams = fnParams;
}
+ msg.setAggExpr(aggFnParams.createTAggregateExpr(isMergeAggFn));
} else {
msg.node_type = TExprNodeType.FUNCTION_CALL;
}
@@ -1143,14 +1152,15 @@ public class FunctionCallExpr extends Expr {
}
public static FunctionCallExpr createMergeAggCall(
- FunctionCallExpr agg, List<Expr> params) {
+ FunctionCallExpr agg, List<Expr> intermediateParams, List<Expr>
realParams) {
Preconditions.checkState(agg.isAnalyzed);
Preconditions.checkState(agg.isAggregateFunction());
FunctionCallExpr result = new FunctionCallExpr(
- agg.fnName, new FunctionParams(false, params), true);
+ agg.fnName, new FunctionParams(false, intermediateParams),
true);
// Inherit the function object from 'agg'.
result.fn = agg.fn;
result.type = agg.type;
+ result.setAggFnParams(new FunctionParams(false, realParams));
return result;
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionParams.java
b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionParams.java
index 32cfba0351..3b77ec52b6 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionParams.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionParams.java
@@ -21,12 +21,15 @@
package org.apache.doris.analysis;
import org.apache.doris.common.io.Writable;
+import org.apache.doris.thrift.TAggregateExpr;
+import org.apache.doris.thrift.TTypeDesc;
import com.google.common.collect.Lists;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
+import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
@@ -62,6 +65,18 @@ public class FunctionParams implements Writable {
return new FunctionParams();
}
+ public TAggregateExpr createTAggregateExpr(boolean isMergeAggFn) {
+ List<TTypeDesc> paramTypes = new ArrayList<TTypeDesc>();
+ if (exprs != null) {
+ for (Expr expr : exprs) {
+ TTypeDesc desc = expr.getType().toThrift();
+ desc.setIsNullable(expr.isNullable());
+ paramTypes.add(desc);
+ }
+ }
+ return new TAggregateExpr(isMergeAggFn, paramTypes);
+ }
+
public boolean isStar() {
return isStar;
}
diff --git a/gensrc/thrift/Exprs.thrift b/gensrc/thrift/Exprs.thrift
index 450148f381..50c9119410 100644
--- a/gensrc/thrift/Exprs.thrift
+++ b/gensrc/thrift/Exprs.thrift
@@ -73,6 +73,7 @@ enum TExprNodeType {
struct TAggregateExpr {
// Indicates whether this expr is the merge() of an aggregation.
1: required bool is_merge_agg
+ 2: required list<Types.TTypeDesc> param_types
}
struct TBoolLiteral {
1: required bool value
diff --git a/gensrc/thrift/Types.thrift b/gensrc/thrift/Types.thrift
index 381f2879e9..a7212e0476 100644
--- a/gensrc/thrift/Types.thrift
+++ b/gensrc/thrift/Types.thrift
@@ -145,6 +145,7 @@ struct TTypeNode {
// to TTypeDesc. In future, we merge these two to one
struct TTypeDesc {
1: list<TTypeNode> types
+ 2: optional bool is_nullable
}
enum TAggregationType {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]