This is an automated email from the ASF dual-hosted git repository.
morningman pushed a commit to branch dev-1.1.2
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/dev-1.1.2 by this push:
new ccf2f82be4 [fix](function)fix max_by function bug (#11745)
ccf2f82be4 is described below
commit ccf2f82be43d9170cb0c3373bd27338573e6d7c9
Author: starocean999 <[email protected]>
AuthorDate: Mon Aug 15 09:06:54 2022 +0800
[fix](function)fix max_by function bug (#11745)
This PR does the same thing as #10650. Because the code base has diverged so
much, it is easier to make the changes directly on dev-1.1.2 than to
cherry-pick them.
---
be/src/vec/CMakeLists.txt | 1 +
be/src/vec/data_types/data_type_factory.cpp | 100 +++++++++++++++++++++
be/src/vec/data_types/data_type_factory.hpp | 8 +-
be/src/vec/exprs/vectorized_agg_fn.cpp | 28 +++---
be/src/vec/exprs/vectorized_agg_fn.h | 2 +-
.../org/apache/doris/analysis/AggregateInfo.java | 18 ++--
.../apache/doris/analysis/FunctionCallExpr.java | 20 ++++-
.../org/apache/doris/analysis/FunctionParams.java | 17 ++++
gensrc/thrift/Exprs.thrift | 1 +
gensrc/thrift/Types.thrift | 1 +
10 files changed, 166 insertions(+), 30 deletions(-)
diff --git a/be/src/vec/CMakeLists.txt b/be/src/vec/CMakeLists.txt
index d8adfe6d40..56cfdfcc14 100644
--- a/be/src/vec/CMakeLists.txt
+++ b/be/src/vec/CMakeLists.txt
@@ -73,6 +73,7 @@ set(VEC_FILES
data_types/nested_utils.cpp
data_types/data_type_date.cpp
data_types/data_type_date_time.cpp
+ data_types/data_type_factory.cpp
exec/vaggregation_node.cpp
exec/ves_http_scan_node.cpp
exec/ves_http_scanner.cpp
diff --git a/be/src/vec/data_types/data_type_factory.cpp
b/be/src/vec/data_types/data_type_factory.cpp
new file mode 100644
index 0000000000..d78679d377
--- /dev/null
+++ b/be/src/vec/data_types/data_type_factory.cpp
@@ -0,0 +1,100 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+// This file is copied from
+// https://github.com/ClickHouse/ClickHouse/blob/master/src/DataTypes/DataTypeFactory.cpp
+// and modified by Doris
+
+#include "vec/data_types/data_type_factory.hpp"
+#include "runtime/types.h"
+#include "vec/data_types/data_type.h"
+#include "vec/data_types/data_type_bitmap.h"
+#include "vec/data_types/data_type_date.h"
+#include "vec/data_types/data_type_date_time.h"
+#include "vec/data_types/data_type_decimal.h"
+#include "vec/data_types/data_type_nothing.h"
+#include "vec/data_types/data_type_nullable.h"
+#include "vec/data_types/data_type_number.h"
+#include "vec/data_types/data_type_string.h"
+
+namespace doris::vectorized {
+
+DataTypePtr DataTypeFactory::create_data_type(const TypeDescriptor& col_desc,
bool is_nullable) {
+ DataTypePtr nested = nullptr;
+ switch (col_desc.type) {
+ case TYPE_BOOLEAN:
+ nested = std::make_shared<vectorized::DataTypeUInt8>();
+ break;
+ case TYPE_TINYINT:
+ nested = std::make_shared<vectorized::DataTypeInt8>();
+ break;
+ case TYPE_SMALLINT:
+ nested = std::make_shared<vectorized::DataTypeInt16>();
+ break;
+ case TYPE_INT:
+ nested = std::make_shared<vectorized::DataTypeInt32>();
+ break;
+ case TYPE_FLOAT:
+ nested = std::make_shared<vectorized::DataTypeFloat32>();
+ break;
+ case TYPE_BIGINT:
+ nested = std::make_shared<vectorized::DataTypeInt64>();
+ break;
+ case TYPE_LARGEINT:
+ nested = std::make_shared<vectorized::DataTypeInt128>();
+ break;
+ case TYPE_DATE:
+ nested = std::make_shared<vectorized::DataTypeDate>();
+ break;
+ case TYPE_DATETIME:
+ nested = std::make_shared<vectorized::DataTypeDateTime>();
+ break;
+ case TYPE_TIME:
+ case TYPE_DOUBLE:
+ nested = std::make_shared<vectorized::DataTypeFloat64>();
+ break;
+ case TYPE_STRING:
+ case TYPE_CHAR:
+ case TYPE_VARCHAR:
+ nested = std::make_shared<vectorized::DataTypeString>();
+ break;
+ case TYPE_HLL:
+ nested = std::make_shared<vectorized::DataTypeHLL>();
+ break;
+ case TYPE_OBJECT:
+ nested = std::make_shared<vectorized::DataTypeBitMap>();
+ break;
+ case TYPE_DECIMALV2:
+ nested =
std::make_shared<vectorized::DataTypeDecimal<vectorized::Decimal128>>(27, 9);
+ break;
+ // Just Mock A NULL Type in Vec Exec Engine
+ case TYPE_NULL:
+ nested = std::make_shared<vectorized::DataTypeUInt8>();
+ break;
+ case INVALID_TYPE:
+ default:
+ DCHECK(false) << "invalid PrimitiveType:" << (int)col_desc.type;
+ break;
+ }
+
+ if (nested && is_nullable) {
+ return std::make_shared<vectorized::DataTypeNullable>(nested);
+ }
+ return nested;
+}
+
+
+} // namespace doris::vectorized
diff --git a/be/src/vec/data_types/data_type_factory.hpp
b/be/src/vec/data_types/data_type_factory.hpp
index e06a962c2f..abe84cfe52 100644
--- a/be/src/vec/data_types/data_type_factory.hpp
+++ b/be/src/vec/data_types/data_type_factory.hpp
@@ -21,7 +21,7 @@
#pragma once
#include <mutex>
#include <string>
-
+#include "runtime/types.h"
#include "vec/data_types/data_type.h"
#include "vec/data_types/data_type_date.h"
#include "vec/data_types/data_type_date_time.h"
@@ -74,6 +74,12 @@ public:
return _empty_string;
}
+ DataTypePtr create_data_type(const TypeDescriptor& col_desc, bool
is_nullable = true);
+
+ DataTypePtr create_data_type(const TTypeDesc& raw_type) {
+ return create_data_type(TypeDescriptor::from_thrift(raw_type),
raw_type.is_nullable);
+ }
+
private:
void regist_data_type(const std::string& name, const DataTypePtr&
data_type) {
_data_type_map.emplace(name, data_type);
diff --git a/be/src/vec/exprs/vectorized_agg_fn.cpp
b/be/src/vec/exprs/vectorized_agg_fn.cpp
index 0987190069..cab89ae7ff 100644
--- a/be/src/vec/exprs/vectorized_agg_fn.cpp
+++ b/be/src/vec/exprs/vectorized_agg_fn.cpp
@@ -23,6 +23,7 @@
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/columns/column_nullable.h"
#include "vec/core/materialize_block.h"
+#include "vec/data_types/data_type_factory.hpp"
#include "vec/data_types/data_type_nullable.h"
#include "vec/exprs/vexpr.h"
@@ -32,18 +33,23 @@ AggFnEvaluator::AggFnEvaluator(const TExprNode& desc)
: _fn(desc.fn),
_is_merge(desc.agg_expr.is_merge_agg),
_return_type(TypeDescriptor::from_thrift(desc.fn.ret_type)),
-
_intermediate_type(TypeDescriptor::from_thrift(desc.fn.aggregate_fn.intermediate_type)),
_intermediate_slot_desc(nullptr),
_output_slot_desc(nullptr),
_exec_timer(nullptr),
_merge_timer(nullptr),
_expr_timer(nullptr) {
- if (desc.__isset.is_nullable) {
- _data_type = IDataType::from_thrift(_return_type.type,
desc.is_nullable);
- } else {
- _data_type = IDataType::from_thrift(_return_type.type);
+ if (desc.__isset.is_nullable) {
+ _data_type = IDataType::from_thrift(_return_type.type,
desc.is_nullable);
+ } else {
+ _data_type = IDataType::from_thrift(_return_type.type);
+ }
+ if (desc.agg_expr.__isset.param_types) {
+ auto& param_types = desc.agg_expr.param_types;
+ for (auto raw_type : param_types) {
+
_argument_types.push_back(DataTypeFactory::instance().create_data_type(raw_type));
}
}
+}
Status AggFnEvaluator::create(ObjectPool* pool, const TExpr& desc,
AggFnEvaluator** result) {
*result = pool->add(new AggFnEvaluator(desc.nodes[0]));
@@ -73,21 +79,21 @@ Status AggFnEvaluator::prepare(RuntimeState* state, const
RowDescriptor& desc, M
Status status = VExpr::prepare(_input_exprs_ctxs, state, desc,
mem_tracker);
RETURN_IF_ERROR(status);
- DataTypes argument_types;
- argument_types.reserve(_input_exprs_ctxs.size());
+ DataTypes tmp_argument_types;
+ tmp_argument_types.reserve(_input_exprs_ctxs.size());
std::vector<std::string_view> child_expr_name;
- doris::vectorized::Array params;
// prepare for argument
for (int i = 0; i < _input_exprs_ctxs.size(); ++i) {
auto data_type = _input_exprs_ctxs[i]->root()->data_type();
- argument_types.emplace_back(data_type);
+ tmp_argument_types.emplace_back(data_type);
child_expr_name.emplace_back(_input_exprs_ctxs[i]->root()->expr_name());
}
- _function =
AggregateFunctionSimpleFactory::instance().get(_fn.name.function_name,
argument_types,
- params,
_data_type->is_nullable());
+ _function = AggregateFunctionSimpleFactory::instance().get(
+ _fn.name.function_name, _argument_types.empty() ?
tmp_argument_types : _argument_types,
+ {}, _data_type->is_nullable());
if (_function == nullptr) {
return Status::InternalError(
fmt::format("Agg Function {} is not implemented",
_fn.name.function_name));
diff --git a/be/src/vec/exprs/vectorized_agg_fn.h
b/be/src/vec/exprs/vectorized_agg_fn.h
index 0f1f145ced..ea130c94de 100644
--- a/be/src/vec/exprs/vectorized_agg_fn.h
+++ b/be/src/vec/exprs/vectorized_agg_fn.h
@@ -79,7 +79,7 @@ private:
void _calc_argment_columns(Block* block);
const TypeDescriptor _return_type;
- const TypeDescriptor _intermediate_type;
+ DataTypes _argument_types;
const SlotDescriptor* _intermediate_slot_desc;
const SlotDescriptor* _output_slot_desc;
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java
b/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java
index 128e60df7a..a82f2a7071 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java
@@ -19,7 +19,6 @@ package org.apache.doris.analysis;
import org.apache.doris.catalog.FunctionSet;
import org.apache.doris.common.AnalysisException;
-import org.apache.doris.common.util.VectorizedUtil;
import org.apache.doris.planner.DataPartition;
import org.apache.doris.thrift.TPartitionType;
@@ -459,17 +458,10 @@ public final class AggregateInfo extends
AggregateInfoBase {
for (int i = 0; i < getAggregateExprs().size(); ++i) {
FunctionCallExpr inputExpr = getAggregateExprs().get(i);
Preconditions.checkState(inputExpr.isAggregateFunction());
- List<Expr> paramExprs = new ArrayList<>();
- // TODO(zhannngchen), change intermediate argument to a list, and
remove this
- // ad-hoc logic
- if ((inputExpr.fn.functionName().equals("max_by") ||
- inputExpr.fn.functionName().equals("min_by")) &&
VectorizedUtil.isVectorized()) {
- paramExprs.addAll(inputExpr.getFnParams().exprs());
- } else {
- paramExprs.add(new SlotRef(inputDesc.getSlots().get(i +
getGroupingExprs().size())));
- }
+ Expr aggExprParam = new SlotRef(inputDesc.getSlots().get(i +
getGroupingExprs().size()));
FunctionCallExpr aggExpr = FunctionCallExpr.createMergeAggCall(
- inputExpr, paramExprs);
+ inputExpr, Lists.newArrayList(aggExprParam),
inputExpr.getFnParams().exprs());
+
aggExpr.analyzeNoThrow(analyzer);
aggExprs.add(aggExpr);
}
@@ -586,11 +578,11 @@ public final class AggregateInfo extends
AggregateInfoBase {
for (int i = 0; i < aggregateExprs_.size(); ++i) {
FunctionCallExpr inputExpr = aggregateExprs_.get(i);
Preconditions.checkState(inputExpr.isAggregateFunction());
- // we're aggregating an output slot of the 1st agg phase
Expr aggExprParam =
new SlotRef(inputDesc.getSlots().get(i +
getGroupingExprs().size()));
+
FunctionCallExpr aggExpr = FunctionCallExpr.createMergeAggCall(
- inputExpr, Lists.newArrayList(aggExprParam));
+ inputExpr, Lists.newArrayList(aggExprParam),
inputExpr.getFnParams().exprs());
secondPhaseAggExprs.add(aggExpr);
}
Preconditions.checkState(
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
index 0021a1cf9b..58bca1cbf3 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
@@ -33,7 +33,6 @@ import org.apache.doris.common.ErrorReport;
import org.apache.doris.common.util.VectorizedUtil;
import org.apache.doris.mysql.privilege.PrivPredicate;
import org.apache.doris.qe.ConnectContext;
-import org.apache.doris.thrift.TAggregateExpr;
import org.apache.doris.thrift.TExprNode;
import org.apache.doris.thrift.TExprNodeType;
@@ -66,6 +65,9 @@ public class FunctionCallExpr extends Expr {
// private BuiltinAggregateFunction.Operator aggOp;
private FunctionParams fnParams;
+ // represents the original parameters of the aggregate function
+ private FunctionParams aggFnParams;
+
// check analytic function
private boolean isAnalyticFnCall = false;
// check table function
@@ -89,6 +91,10 @@ public class FunctionCallExpr extends Expr {
private boolean isRewrote = false;
+ public void setAggFnParams(FunctionParams aggFnParams) {
+ this.aggFnParams = aggFnParams;
+ }
+
public void setIsAnalyticFnCall(boolean v) {
isAnalyticFnCall = v;
}
@@ -150,6 +156,7 @@ public class FunctionCallExpr extends Expr {
// aggOp = e.aggOp;
isAnalyticFnCall = e.isAnalyticFnCall;
fnParams = params;
+ aggFnParams = e.aggFnParams;
// Just inherit the function object from 'e'.
fn = e.fn;
this.isMergeAggFn = e.isMergeAggFn;
@@ -172,6 +179,7 @@ public class FunctionCallExpr extends Expr {
} else {
fnParams = new FunctionParams(other.fnParams.isDistinct(),
children);
}
+ aggFnParams = other.aggFnParams;
this.isMergeAggFn = other.isMergeAggFn;
fn = other.fn;
this.isTableFnCall = other.isTableFnCall;
@@ -354,7 +362,10 @@ public class FunctionCallExpr extends Expr {
if (isAggregate() || isAnalyticFnCall) {
msg.node_type = TExprNodeType.AGG_EXPR;
if (!isAnalyticFnCall) {
- msg.setAggExpr(new TAggregateExpr(isMergeAggFn));
+ if (aggFnParams == null) {
+ aggFnParams = fnParams;
+ }
+ msg.setAggExpr(aggFnParams.createTAggregateExpr(isMergeAggFn));
}
} else {
msg.node_type = TExprNodeType.FUNCTION_CALL;
@@ -1041,14 +1052,15 @@ public class FunctionCallExpr extends Expr {
}
public static FunctionCallExpr createMergeAggCall(
- FunctionCallExpr agg, List<Expr> params) {
+ FunctionCallExpr agg, List<Expr> intermediateParams, List<Expr>
realParams) {
Preconditions.checkState(agg.isAnalyzed);
Preconditions.checkState(agg.isAggregateFunction());
FunctionCallExpr result = new FunctionCallExpr(
- agg.fnName, new FunctionParams(false, params), true);
+ agg.fnName, new FunctionParams(false, intermediateParams),
true);
// Inherit the function object from 'agg'.
result.fn = agg.fn;
result.type = agg.type;
+ result.setAggFnParams(new FunctionParams(false, realParams));
return result;
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionParams.java
b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionParams.java
index 59d85ace26..742cbf4b4c 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionParams.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionParams.java
@@ -18,12 +18,15 @@
package org.apache.doris.analysis;
import org.apache.doris.common.io.Writable;
+import org.apache.doris.thrift.TAggregateExpr;
+import org.apache.doris.thrift.TTypeDesc;
import com.google.common.collect.Lists;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
+import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
@@ -118,4 +121,18 @@ public class FunctionParams implements Writable {
}
return result;
}
+
+ public TAggregateExpr createTAggregateExpr(boolean isMergeAggFn) {
+ List<TTypeDesc> paramTypes = new ArrayList<TTypeDesc>();
+ if (exprs != null) {
+ for (Expr expr : exprs) {
+ TTypeDesc desc = expr.getType().toThrift();
+ desc.setIsNullable(expr.isNullable());
+ paramTypes.add(desc);
+ }
+ }
+ TAggregateExpr aggExpr = new TAggregateExpr(isMergeAggFn);
+ aggExpr.setParamTypes(paramTypes);
+ return aggExpr;
+ }
}
diff --git a/gensrc/thrift/Exprs.thrift b/gensrc/thrift/Exprs.thrift
index 450148f381..df44b5ae60 100644
--- a/gensrc/thrift/Exprs.thrift
+++ b/gensrc/thrift/Exprs.thrift
@@ -73,6 +73,7 @@ enum TExprNodeType {
struct TAggregateExpr {
// Indicates whether this expr is the merge() of an aggregation.
1: required bool is_merge_agg
+ 2: optional list<Types.TTypeDesc> param_types
}
struct TBoolLiteral {
1: required bool value
diff --git a/gensrc/thrift/Types.thrift b/gensrc/thrift/Types.thrift
index 9a232915ee..ae5b5e1853 100644
--- a/gensrc/thrift/Types.thrift
+++ b/gensrc/thrift/Types.thrift
@@ -135,6 +135,7 @@ struct TTypeNode {
// to TTypeDesc. In future, we merge these two to one
struct TTypeDesc {
1: list<TTypeNode> types
+ 2: optional bool is_nullable
}
enum TAggregationType {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]