This is an automated email from the ASF dual-hosted git repository.
morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new dca6b8175d7 [fix](nereids) build agg for random distributed agg table
in bindRelation phase (#40181)
dca6b8175d7 is described below
commit dca6b8175d7557b0b91c7a5c97d2656b02dced6f
Author: starocean999 <[email protected]>
AuthorDate: Wed Sep 11 18:16:03 2024 +0800
[fix](nereids) build agg for random distributed agg table in bindRelation
phase (#40181)
It's better to build the aggregate for a random distributed AGG_KEYS table in
the bindRelation phase instead of in the BuildAggForRandomDistributedTable rule.
---
.../doris/nereids/jobs/executor/Analyzer.java | 3 -
.../org/apache/doris/nereids/rules/RuleType.java | 4 -
.../doris/nereids/rules/analysis/BindRelation.java | 125 +++++++++-
.../BuildAggForRandomDistributedTable.java | 271 ---------------------
.../doris/nereids/rules/analysis/CheckPolicy.java | 21 +-
.../nereids/rules/analysis/BindRelationTest.java | 23 ++
.../nereids/rules/analysis/CheckRowPolicyTest.java | 97 ++++++++
.../aggregate/select_random_distributed_tbl.out | 14 +-
.../aggregate/select_random_distributed_tbl.groovy | 19 +-
9 files changed, 285 insertions(+), 292 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java
index 605a848181c..1ffbac97d74 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java
@@ -26,7 +26,6 @@ import org.apache.doris.nereids.rules.analysis.BindExpression;
import org.apache.doris.nereids.rules.analysis.BindRelation;
import
org.apache.doris.nereids.rules.analysis.BindRelation.CustomTableResolver;
import org.apache.doris.nereids.rules.analysis.BindSink;
-import
org.apache.doris.nereids.rules.analysis.BuildAggForRandomDistributedTable;
import org.apache.doris.nereids.rules.analysis.CheckAfterBind;
import org.apache.doris.nereids.rules.analysis.CheckAnalysis;
import org.apache.doris.nereids.rules.analysis.CheckPolicy;
@@ -163,8 +162,6 @@ public class Analyzer extends AbstractBatchJobExecutor {
topDown(new EliminateGroupByConstant()),
topDown(new SimplifyAggGroupBy()),
- // run BuildAggForRandomDistributedTable before NormalizeAggregate
in order to optimize the agg plan
- topDown(new BuildAggForRandomDistributedTable()),
topDown(new NormalizeAggregate()),
topDown(new HavingToFilter()),
bottomUp(new SemiJoinCommute()),
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index b6ab6e2dac2..d345d9057e9 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -342,10 +342,6 @@ public enum RuleType {
// topn opts
DEFER_MATERIALIZE_TOP_N_RESULT(RuleTypeClass.REWRITE),
- // pre agg for random distributed table
- BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_PROJECT_SCAN(RuleTypeClass.REWRITE),
- BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_FILTER_SCAN(RuleTypeClass.REWRITE),
- BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_AGG_SCAN(RuleTypeClass.REWRITE),
// short circuit rule
SHOR_CIRCUIT_POINT_QUERY(RuleTypeClass.REWRITE),
// exploration rules
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java
index cedc4e92ff1..c81fcc25b83 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java
@@ -17,10 +17,17 @@
package org.apache.doris.nereids.rules.analysis;
+import org.apache.doris.catalog.AggStateType;
+import org.apache.doris.catalog.AggregateType;
import org.apache.doris.catalog.Column;
+import org.apache.doris.catalog.DistributionInfo;
+import org.apache.doris.catalog.Env;
+import org.apache.doris.catalog.FunctionRegistry;
+import org.apache.doris.catalog.KeysType;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.Partition;
import org.apache.doris.catalog.TableIf;
+import org.apache.doris.catalog.Type;
import org.apache.doris.catalog.View;
import org.apache.doris.common.Config;
import org.apache.doris.common.Pair;
@@ -43,13 +50,26 @@ import
org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import
org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
+import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
+import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion;
+import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnion;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import org.apache.doris.nereids.trees.expressions.functions.agg.QuantileUnion;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PreAggStatus;
import org.apache.doris.nereids.trees.plans.algebra.Relation;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.logical.LogicalEsScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFileScan;
@@ -73,6 +93,7 @@ import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.apache.commons.collections.CollectionUtils;
+import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
@@ -214,7 +235,109 @@ public class BindRelation extends OneAnalysisRuleFactory {
unboundRelation.getTableSample());
}
}
- return checkAndAddDeleteSignFilter(scan, ConnectContext.get(),
(OlapTable) table);
+ if (needGenerateLogicalAggForRandomDistAggTable(scan)) {
+ // it's a random distribution agg table
+ // add agg on olap scan
+ return preAggForRandomDistribution(scan);
+ } else {
+ // it's a duplicate, unique or hash distribution agg table
+ // add delete sign filter on olap scan if needed
+ return checkAndAddDeleteSignFilter(scan, ConnectContext.get(),
(OlapTable) table);
+ }
+ }
+
+    /**
+     * Decide whether the bound olap scan must be wrapped in a LogicalAggregate.
+     * This is true only when the current statement is a query and the scanned table is an
+     * AGG_KEYS table whose default distribution is RANDOM.
+     *
+     * @param olapScan olap scan produced by binding the relation
+     * @return true if {@link #preAggForRandomDistribution} should be applied to the scan
+     */
+    private boolean needGenerateLogicalAggForRandomDistAggTable(LogicalOlapScan olapScan) {
+        ConnectContext context = ConnectContext.get();
+        // the agg node is only for queries; it must not be added when deleting from a random
+        // distributed table, see https://github.com/apache/doris/pull/37985 for more info
+        if (context == null || context.getState() == null || !context.getState().isQuery()) {
+            return false;
+        }
+        OlapTable olapTable = olapScan.getTable();
+        boolean isAggKeys = olapTable.getKeysType() == KeysType.AGG_KEYS;
+        DistributionInfo distributionInfo = olapTable.getDefaultDistributionInfo();
+        return isAggKeys && distributionInfo.getType() == DistributionInfo.DistributionInfoType.RANDOM;
+    }
+
+    /**
+     * Add a LogicalAggregate above the olap scan of a random distributed AGG_KEYS table.
+     * Key columns become group-by slots. Every value column is wrapped in the aggregate function
+     * matching its aggregation type (see generateAggFunction). Each output reuses the ExprId of
+     * the scan's existing slot for that column.
+     *
+     * @param olapScan olap scan plan of a random distributed agg table
+     * @return the aggregate over olapScan, or olapScan itself when some value column has an
+     *         aggregation type that generateAggFunction cannot map (returns null for)
+     */
+    private LogicalPlan preAggForRandomDistribution(LogicalOlapScan olapScan) {
+        OlapTable olapTable = olapScan.getTable();
+        List<Slot> childOutputSlots = olapScan.computeOutput();
+        List<Expression> groupByExpressions = new ArrayList<>();
+        List<NamedExpression> outputExpressions = new ArrayList<>();
+        List<Column> columns = olapTable.getBaseSchema();
+
+        for (Column col : columns) {
+            // use the existing slot (ExprId) of this column from the scan output when there is one
+            SlotReference slot = SlotReference.fromColumn(olapTable, col, col.getName(), olapScan.qualified());
+            ExprId exprId = slot.getExprId();
+            for (Slot childSlot : childOutputSlots) {
+                // compare names by value: == only tests reference identity and would miss
+                // equal names held in different String instances
+                if (childSlot instanceof SlotReference
+                        && col.getName().equals(((SlotReference) childSlot).getName())) {
+                    exprId = childSlot.getExprId();
+                    slot = slot.withExprId(exprId);
+                    break;
+                }
+            }
+            if (col.isKey()) {
+                groupByExpressions.add(slot);
+                outputExpressions.add(slot);
+            } else {
+                Expression function = generateAggFunction(slot, col);
+                // unsupported aggregation type for this value column: DO NOT rewrite
+                if (function == null) {
+                    return olapScan;
+                }
+                Alias alias = new Alias(exprId, ImmutableList.of(function), col.getName(),
+                        olapScan.qualified(), true);
+                outputExpressions.add(alias);
+            }
+        }
+        LogicalAggregate<LogicalOlapScan> aggregate = new LogicalAggregate<>(groupByExpressions, outputExpressions,
+                olapScan);
+        return aggregate;
+    }
+
+    /**
+     * Build the aggregate function that merges the values of a value column.
+     * The function is chosen by the column's aggregation type.
+     *
+     * @param slot slot of the column, used as the argument of the aggregate function
+     * @param column column whose aggregation type (and, for GENERIC, type) selects the function
+     * @return the aggregate function, or null when the aggregation type has no mapping here
+     *         (including GENERIC columns that are not agg_state typed)
+     */
+    private Expression generateAggFunction(SlotReference slot, Column column) {
+        AggregateType columnAggType = column.getAggregationType();
+        switch (columnAggType) {
+            case SUM:
+                return new Sum(slot);
+            case MAX:
+                return new Max(slot);
+            case MIN:
+                return new Min(slot);
+            case HLL_UNION:
+                return new HllUnion(slot);
+            case BITMAP_UNION:
+                return new BitmapUnion(slot);
+            case QUANTILE_UNION:
+                return new QuantileUnion(slot);
+            case GENERIC: {
+                Type columnType = column.getType();
+                if (columnType.isAggStateType()) {
+                    // merge several agg_state values into one with <function>_union
+                    String unionName = ((AggStateType) columnType).getFunctionName()
+                            + AggCombinerFunctionBuilder.UNION_SUFFIX;
+                    FunctionRegistry registry = Env.getCurrentEnv().getFunctionRegistry();
+                    FunctionBuilder unionBuilder = registry.findFunctionBuilder(unionName, slot);
+                    return unionBuilder.build(unionName, ImmutableList.of(slot)).first;
+                }
+                return null;
+            }
+            default:
+                return null;
+        }
}
/**
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BuildAggForRandomDistributedTable.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BuildAggForRandomDistributedTable.java
deleted file mode 100644
index e547a55f9e3..00000000000
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BuildAggForRandomDistributedTable.java
+++ /dev/null
@@ -1,271 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.doris.nereids.rules.analysis;
-
-import org.apache.doris.catalog.AggStateType;
-import org.apache.doris.catalog.AggregateType;
-import org.apache.doris.catalog.Column;
-import org.apache.doris.catalog.DistributionInfo;
-import org.apache.doris.catalog.DistributionInfo.DistributionInfoType;
-import org.apache.doris.catalog.Env;
-import org.apache.doris.catalog.FunctionRegistry;
-import org.apache.doris.catalog.KeysType;
-import org.apache.doris.catalog.OlapTable;
-import org.apache.doris.catalog.Type;
-import org.apache.doris.nereids.rules.Rule;
-import org.apache.doris.nereids.rules.RuleType;
-import org.apache.doris.nereids.trees.expressions.Alias;
-import org.apache.doris.nereids.trees.expressions.ExprId;
-import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.expressions.NamedExpression;
-import org.apache.doris.nereids.trees.expressions.Slot;
-import org.apache.doris.nereids.trees.expressions.SlotReference;
-import
org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
-import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
-import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
-import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapFunction;
-import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion;
-import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
-import org.apache.doris.nereids.trees.expressions.functions.agg.HllFunction;
-import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnion;
-import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
-import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
-import org.apache.doris.nereids.trees.expressions.functions.agg.QuantileUnion;
-import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
-import org.apache.doris.nereids.trees.plans.Plan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
-import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
-import org.apache.doris.qe.ConnectContext;
-
-import com.google.common.collect.ImmutableList;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Set;
-
-/**
- * build agg plan for querying random distributed table
- */
-public class BuildAggForRandomDistributedTable implements AnalysisRuleFactory {
-
- @Override
- public List<Rule> buildRules() {
- return ImmutableList.of(
- // Project(Scan) -> project(agg(scan))
- logicalProject(logicalOlapScan())
- .when(this::isQuery)
- .when(project ->
isRandomDistributedTbl(project.child()))
- .then(project -> preAggForRandomDistribution(project,
project.child()))
-
.toRule(RuleType.BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_PROJECT_SCAN),
- // agg(scan) -> agg(agg(scan)), agg(agg) may optimized by
MergeAggregate
- logicalAggregate(logicalOlapScan())
- .when(this::isQuery)
- .when(agg -> isRandomDistributedTbl(agg.child()))
- .whenNot(agg -> {
- Set<AggregateFunction> functions =
agg.getAggregateFunctions();
- List<Expression> groupByExprs =
agg.getGroupByExpressions();
- // check if need generate an inner agg plan or not
- // should not rewrite twice if we had rewritten
olapScan to aggregate(olapScan)
- return
functions.stream().allMatch(this::aggTypeMatch) && groupByExprs.stream()
- .allMatch(this::isKeyOrConstantExpr);
- })
- .then(agg -> preAggForRandomDistribution(agg,
agg.child()))
-
.toRule(RuleType.BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_AGG_SCAN),
- // filter(scan) -> filter(agg(scan))
- logicalFilter(logicalOlapScan())
- .when(this::isQuery)
- .when(filter -> isRandomDistributedTbl(filter.child()))
- .then(filter -> preAggForRandomDistribution(filter,
filter.child()))
-
.toRule(RuleType.BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_FILTER_SCAN));
-
- }
-
- /**
- * check the olapTable of olapScan is randomDistributed table
- *
- * @param olapScan olap scan plan
- * @return true if olapTable is randomDistributed table
- */
- private boolean isRandomDistributedTbl(LogicalOlapScan olapScan) {
- OlapTable olapTable = olapScan.getTable();
- KeysType keysType = olapTable.getKeysType();
- DistributionInfo distributionInfo =
olapTable.getDefaultDistributionInfo();
- return keysType == KeysType.AGG_KEYS && distributionInfo.getType() ==
DistributionInfoType.RANDOM;
- }
-
- private boolean isQuery(LogicalPlan plan) {
- return ConnectContext.get() != null
- && ConnectContext.get().getState() != null
- && ConnectContext.get().getState().isQuery();
- }
-
- /**
- * add LogicalAggregate above olapScan for preAgg
- *
- * @param logicalPlan parent plan of olapScan
- * @param olapScan olap scan plan, it may be LogicalProject,
LogicalFilter, LogicalAggregate
- * @return rewritten plan
- */
- private Plan preAggForRandomDistribution(LogicalPlan logicalPlan,
LogicalOlapScan olapScan) {
- OlapTable olapTable = olapScan.getTable();
- List<Slot> childOutputSlots = olapScan.computeOutput();
- List<Expression> groupByExpressions = new ArrayList<>();
- List<NamedExpression> outputExpressions = new ArrayList<>();
- List<Column> columns = olapTable.getBaseSchema();
-
- for (Column col : columns) {
- // use exist slot in the plan
- SlotReference slot = SlotReference.fromColumn(olapTable, col,
col.getName(), olapScan.getQualifier());
- ExprId exprId = slot.getExprId();
- for (Slot childSlot : childOutputSlots) {
- if (childSlot instanceof SlotReference && ((SlotReference)
childSlot).getName() == col.getName()) {
- exprId = childSlot.getExprId();
- slot = slot.withExprId(exprId);
- break;
- }
- }
- if (col.isKey()) {
- groupByExpressions.add(slot);
- outputExpressions.add(slot);
- } else {
- Expression function = generateAggFunction(slot, col);
- // DO NOT rewrite
- if (function == null) {
- return logicalPlan;
- }
- Alias alias = new Alias(exprId, function, col.getName());
- outputExpressions.add(alias);
- }
- }
- LogicalAggregate<LogicalOlapScan> aggregate = new
LogicalAggregate<>(groupByExpressions, outputExpressions,
- olapScan);
- return logicalPlan.withChildren(aggregate);
- }
-
- /**
- * generate aggregation function according to the aggType of column
- *
- * @param slot slot of column
- * @return aggFunction generated
- */
- private Expression generateAggFunction(SlotReference slot, Column column) {
- AggregateType aggregateType = column.getAggregationType();
- switch (aggregateType) {
- case SUM:
- return new Sum(slot);
- case MAX:
- return new Max(slot);
- case MIN:
- return new Min(slot);
- case HLL_UNION:
- return new HllUnion(slot);
- case BITMAP_UNION:
- return new BitmapUnion(slot);
- case QUANTILE_UNION:
- return new QuantileUnion(slot);
- case GENERIC:
- Type type = column.getType();
- if (!type.isAggStateType()) {
- return null;
- }
- AggStateType aggState = (AggStateType) type;
- // use AGGREGATE_FUNCTION_UNION to aggregate multiple
agg_state into one
- String funcName = aggState.getFunctionName() +
AggCombinerFunctionBuilder.UNION_SUFFIX;
- FunctionRegistry functionRegistry =
Env.getCurrentEnv().getFunctionRegistry();
- FunctionBuilder builder =
functionRegistry.findFunctionBuilder(funcName, slot);
- return builder.build(funcName, ImmutableList.of(slot)).first;
- default:
- return null;
- }
- }
-
- /**
- * if the agg type of AggregateFunction is as same as the agg type of
column, DO NOT need to rewrite
- *
- * @param function agg function to check
- * @return true if agg type match
- */
- private boolean aggTypeMatch(AggregateFunction function) {
- List<Expression> children = function.children();
- if (function.getName().equalsIgnoreCase("count")) {
- Count count = (Count) function;
- // do not rewrite for count distinct for key column
- if (count.isDistinct()) {
- return children.stream().allMatch(this::isKeyOrConstantExpr);
- }
- if (count.isStar()) {
- return false;
- }
- }
- return children.stream().allMatch(child -> aggTypeMatch(function,
child));
- }
-
- /**
- * check if the agg type of functionCall match the agg type of column
- *
- * @param function the functionCall
- * @param expression expr to check
- * @return true if agg type match
- */
- private boolean aggTypeMatch(AggregateFunction function, Expression
expression) {
- if (expression.children().isEmpty()) {
- if (expression instanceof SlotReference && ((SlotReference)
expression).getColumn().isPresent()) {
- Column col = ((SlotReference) expression).getColumn().get();
- String functionName = function.getName();
- if (col.isKey()) {
- return functionName.equalsIgnoreCase("max") ||
functionName.equalsIgnoreCase("min");
- }
- if (col.isAggregated()) {
- AggregateType aggType = col.getAggregationType();
- // agg type not mach
- if (aggType == AggregateType.GENERIC) {
- return col.getType().isAggStateType();
- }
- if (aggType == AggregateType.HLL_UNION) {
- return function instanceof HllFunction;
- }
- if (aggType == AggregateType.BITMAP_UNION) {
- return function instanceof BitmapFunction;
- }
- return functionName.equalsIgnoreCase(aggType.name());
- }
- }
- return false;
- }
- List<Expression> children = expression.children();
- return children.stream().allMatch(child -> aggTypeMatch(function,
child));
- }
-
- /**
- * check if the columns in expr is key column or constant, if group by
clause contains value column, need rewrite
- *
- * @param expr expr to check
- * @return true if all columns is key column or constant
- */
- private boolean isKeyOrConstantExpr(Expression expr) {
- if (expr instanceof SlotReference && ((SlotReference)
expr).getColumn().isPresent()) {
- Column col = ((SlotReference) expr).getColumn().get();
- return col.isKey();
- } else if (expr.isConstant()) {
- return true;
- }
- List<Expression> children = expr.children();
- return children.stream().allMatch(this::isKeyOrConstantExpr);
- }
-}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckPolicy.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckPolicy.java
index 94f7c36b108..4beed413d09 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckPolicy.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckPolicy.java
@@ -23,6 +23,7 @@ import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalCheckPolicy;
import
org.apache.doris.nereids.trees.plans.logical.LogicalCheckPolicy.RelatedPolicy;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
@@ -49,12 +50,23 @@ public class CheckPolicy implements AnalysisRuleFactory {
logicalCheckPolicy(any().when(child -> !(child
instanceof UnboundRelation))).thenApply(ctx -> {
LogicalCheckPolicy<Plan> checkPolicy = ctx.root;
LogicalFilter<Plan> upperFilter = null;
+ Plan upAgg = null;
Plan child = checkPolicy.child();
// Because the unique table will automatically
include a filter condition
- if (child instanceof LogicalFilter &&
child.bound() && child
- .child(0) instanceof LogicalRelation) {
+ if ((child instanceof LogicalFilter) &&
child.bound()) {
upperFilter = (LogicalFilter) child;
+ if (child.child(0) instanceof LogicalRelation)
{
+ child = child.child(0);
+ } else if (child.child(0) instanceof
LogicalAggregate
+ && child.child(0).child(0) instanceof
LogicalRelation) {
+ upAgg = child.child(0);
+ child = child.child(0).child(0);
+ }
+ }
+ if ((child instanceof LogicalAggregate)
+ && child.bound() && child.child(0)
instanceof LogicalRelation) {
+ upAgg = child;
child = child.child(0);
}
if (!(child instanceof LogicalRelation)
@@ -76,16 +88,17 @@ public class CheckPolicy implements AnalysisRuleFactory {
RelatedPolicy relatedPolicy =
checkPolicy.findPolicy(relation, ctx.cascadesContext);
relatedPolicy.rowPolicyFilter.ifPresent(expression
-> combineFilter.addAll(
ExpressionUtils.extractConjunctionToSet(expression)));
- Plan result = relation;
+ Plan result = upAgg != null ?
upAgg.withChildren(relation) : relation;
if (upperFilter != null) {
combineFilter.addAll(upperFilter.getConjuncts());
}
if (!combineFilter.isEmpty()) {
- result = new LogicalFilter<>(combineFilter,
relation);
+ result = new LogicalFilter<>(combineFilter,
result);
}
if (relatedPolicy.dataMaskProjects.isPresent()) {
result = new
LogicalProject<>(relatedPolicy.dataMaskProjects.get(), result);
}
+
return result;
})
)
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindRelationTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindRelationTest.java
index b14834fd321..67115e67687 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindRelationTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindRelationTest.java
@@ -29,6 +29,7 @@ import org.apache.doris.nereids.rules.RulePromise;
import
org.apache.doris.nereids.rules.analysis.BindRelation.CustomTableResolver;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanRewriter;
@@ -54,6 +55,12 @@ class BindRelationTest extends TestWithFeService implements
GeneratedPlanPattern
+ ")ENGINE=OLAP\n"
+ "DISTRIBUTED BY HASH(`a`) BUCKETS 3\n"
+ "PROPERTIES (\"replication_num\"= \"1\");");
+ createTable("CREATE TABLE db1.tagg ( \n"
+ + " \ta INT,\n"
+ + " \tb INT SUM\n"
+ + ")ENGINE=OLAP AGGREGATE KEY(a)\n "
+ + "DISTRIBUTED BY random BUCKETS 3\n"
+ + "PROPERTIES (\"replication_num\"= \"1\");");
connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION");
}
@@ -125,6 +132,22 @@ class BindRelationTest extends TestWithFeService
implements GeneratedPlanPattern
);
}
+    @Test
+    void bindRandomAggTable() {
+        connectContext.setDatabase(DEFAULT_CLUSTER_PREFIX + DB1);
+        connectContext.getState().setIsQuery(true);
+        UnboundRelation unboundRelation = new UnboundRelation(StatementScopeIdGenerator.newRelationId(),
+                ImmutableList.of("tagg"));
+        Plan plan = PlanRewriter.bottomUpRewrite(unboundRelation, connectContext, new BindRelation());
+
+        // binding a query on the random distributed agg table tagg yields an aggregate on top of the scan,
+        // and both output columns (a, b) keep the fully qualified table name
+        Assertions.assertTrue(plan instanceof LogicalAggregate);
+        ImmutableList<String> expectedQualifier = ImmutableList.of("internal", DEFAULT_CLUSTER_PREFIX + DB1, "tagg");
+        Assertions.assertEquals(expectedQualifier, plan.getOutput().get(0).getQualifier());
+        Assertions.assertEquals(expectedQualifier, plan.getOutput().get(1).getQualifier());
+    }
+
@Override
public RulePromise defaultPromise() {
return RulePromise.REWRITE;
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckRowPolicyTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckRowPolicyTest.java
index 196d99037e2..b807bbbbc7a 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckRowPolicyTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckRowPolicyTest.java
@@ -34,6 +34,9 @@ import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.PartitionInfo;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.FeConstants;
+import org.apache.doris.mysql.privilege.AccessControllerManager;
+import org.apache.doris.mysql.privilege.DataMaskPolicy;
+import org.apache.doris.nereids.analyzer.UnboundRelation;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
@@ -41,6 +44,7 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalCheckPolicy;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
import org.apache.doris.nereids.util.PlanRewriter;
import org.apache.doris.thrift.TStorageType;
@@ -48,17 +52,22 @@ import org.apache.doris.utframe.TestWithFeService;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
+import mockit.Mock;
+import mockit.MockUp;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
import java.util.List;
+import java.util.Optional;
public class CheckRowPolicyTest extends TestWithFeService {
private static String dbName = "check_row_policy";
private static String fullDbName = "" + dbName;
private static String tableName = "table1";
+
+ private static String tableNameRanddomDist = "tableRandomDist";
private static String userName = "user1";
private static String policyName = "policy1";
@@ -76,6 +85,10 @@ public class CheckRowPolicyTest extends TestWithFeService {
+ tableName
+ " (k1 int, k2 int) distributed by hash(k1) buckets 1"
+ " properties(\"replication_num\" = \"1\");");
+ createTable("create table "
+ + tableNameRanddomDist
+ + " (k1 int, k2 int) AGGREGATE KEY(k1, k2) distributed by
random buckets 1"
+ + " properties(\"replication_num\" = \"1\");");
Database db =
Env.getCurrentInternalCatalog().getDbOrMetaException(fullDbName);
long tableId = db.getTableOrMetaException("table1").getId();
olapTable.setId(tableId);
@@ -85,6 +98,7 @@ public class CheckRowPolicyTest extends TestWithFeService {
0, 0, (short) 0,
TStorageType.COLUMN,
KeysType.PRIMARY_KEYS);
+
// create user
UserIdentity user = new UserIdentity(userName, "%");
user.analyze();
@@ -98,6 +112,27 @@ public class CheckRowPolicyTest extends TestWithFeService {
Analyzer analyzer = new Analyzer(connectContext.getEnv(),
connectContext);
grantStmt.analyze(analyzer);
Env.getCurrentEnv().getAuth().grant(grantStmt);
+
+ new MockUp<AccessControllerManager>() {
+ @Mock
+ public Optional<DataMaskPolicy> evalDataMaskPolicy(UserIdentity
currentUser, String ctl,
+ String db, String tbl, String col) {
+ return tbl.equalsIgnoreCase(tableNameRanddomDist)
+ ? Optional.of(new DataMaskPolicy() {
+ @Override
+ public String getMaskTypeDef() {
+ return String.format("concat(%s, '_****_',
%s)", col, col);
+ }
+
+ @Override
+ public String getPolicyIdent() {
+ return String.format("custom policy:
concat(%s, '_****_', %s)", col,
+ col);
+ }
+ })
+ : Optional.empty();
+ }
+ };
}
@Test
@@ -115,6 +150,24 @@ public class CheckRowPolicyTest extends TestWithFeService {
Assertions.assertEquals(plan, relation);
}
+    @Test
+    public void checkUserRandomDist() throws AnalysisException, org.apache.doris.common.AnalysisException {
+        connectContext.getState().setIsQuery(true);
+        UnboundRelation unboundRelation = new UnboundRelation(StatementScopeIdGenerator.newRelationId(),
+                ImmutableList.of(tableNameRanddomDist));
+        Plan plan = PlanRewriter.bottomUpRewrite(unboundRelation, connectContext, new BindRelation());
+        LogicalCheckPolicy checkPolicy = new LogicalCheckPolicy(plan);
+
+        // as root, CheckPolicy is expected to leave the bound plan unchanged
+        useUser("root");
+        Plan rootPlan = PlanRewriter.bottomUpRewrite(checkPolicy, connectContext, new CheckPolicy(),
+                new BindExpression());
+        Assertions.assertEquals(plan, rootPlan);
+
+        // for another user one node (presumably the mocked data-mask projection) is added on top
+        useUser("notFound");
+        Plan otherUserPlan = PlanRewriter.bottomUpRewrite(checkPolicy, connectContext, new CheckPolicy(),
+                new BindExpression());
+        Assertions.assertEquals(plan, otherUserPlan.child(0));
+    }
+
@Test
public void checkNoPolicy() throws
org.apache.doris.common.AnalysisException {
useUser(userName);
@@ -125,6 +178,18 @@ public class CheckRowPolicyTest extends TestWithFeService {
Assertions.assertEquals(plan, relation);
}
+    @Test
+    public void checkNoPolicyRandomDist() throws org.apache.doris.common.AnalysisException {
+        useUser(userName);
+        connectContext.getState().setIsQuery(true);
+        UnboundRelation unboundRelation = new UnboundRelation(StatementScopeIdGenerator.newRelationId(),
+                ImmutableList.of(tableNameRanddomDist));
+        Plan boundPlan = PlanRewriter.bottomUpRewrite(unboundRelation, connectContext, new BindRelation());
+
+        // no row policy exists yet: the bound plan (aggregate over the scan) must survive as the child
+        // of whatever single node CheckPolicy puts on top
+        Plan rewrittenPlan = PlanRewriter.bottomUpRewrite(new LogicalCheckPolicy(boundPlan), connectContext,
+                new CheckPolicy(), new BindExpression());
+        Assertions.assertEquals(boundPlan, rewrittenPlan.child(0));
+    }
+
@Test
public void checkOnePolicy() throws Exception {
useUser(userName);
@@ -152,4 +217,36 @@ public class CheckRowPolicyTest extends TestWithFeService {
+ " ON "
+ tableName);
}
+
+ @Test
+ public void checkOnePolicyRandomDist() throws Exception {
+ useUser(userName);
+ connectContext.getState().setIsQuery(true);
+ Plan plan = PlanRewriter.bottomUpRewrite(new
UnboundRelation(StatementScopeIdGenerator.newRelationId(),
+ ImmutableList.of(tableNameRanddomDist)), connectContext, new
BindRelation());
+
+ LogicalCheckPolicy checkPolicy = new LogicalCheckPolicy(plan);
+ connectContext.getSessionVariable().setEnableNereidsPlanner(true);
+ createPolicy("CREATE ROW POLICY "
+ + policyName
+ + " ON "
+ + tableNameRanddomDist
+ + " AS PERMISSIVE TO "
+ + userName
+ + " USING (k1 = 1)");
+ Plan rewrittenPlan = PlanRewriter.bottomUpRewrite(checkPolicy,
connectContext, new CheckPolicy(),
+ new BindExpression());
+
+ Assertions.assertTrue(rewrittenPlan instanceof LogicalProject
+ && rewrittenPlan.child(0) instanceof LogicalFilter);
+ LogicalFilter filter = (LogicalFilter) rewrittenPlan.child(0);
+ Assertions.assertEquals(filter.child(), plan);
+
Assertions.assertTrue(ImmutableList.copyOf(filter.getConjuncts()).get(0)
instanceof EqualTo);
+ Assertions.assertTrue(filter.getConjuncts().toString().contains("k1#0
= 1"));
+
+ dropPolicy("DROP ROW POLICY "
+ + policyName
+ + " ON "
+ + tableNameRanddomDist);
+ }
}
diff --git
a/regression-test/data/query_p0/aggregate/select_random_distributed_tbl.out
b/regression-test/data/query_p0/aggregate/select_random_distributed_tbl.out
index c03e72c8f9e..eb099225960 100644
--- a/regression-test/data/query_p0/aggregate/select_random_distributed_tbl.out
+++ b/regression-test/data/query_p0/aggregate/select_random_distributed_tbl.out
@@ -217,13 +217,25 @@
-- !sql_17 --
1
+3
-- !sql_18 --
1
+3
-- !sql_19 --
-1
+999999999999999.99
+1999999999999999.98
-- !sql_20 --
1
+3
+
+-- !sql_21 --
+1
+3
+
+-- !sql_22 --
+999999999999999.99
+1999999999999999.98
diff --git
a/regression-test/suites/query_p0/aggregate/select_random_distributed_tbl.groovy
b/regression-test/suites/query_p0/aggregate/select_random_distributed_tbl.groovy
index c818454c261..5c99a0a4aa0 100644
---
a/regression-test/suites/query_p0/aggregate/select_random_distributed_tbl.groovy
+++
b/regression-test/suites/query_p0/aggregate/select_random_distributed_tbl.groovy
@@ -123,7 +123,8 @@ suite("select_random_distributed_tbl") {
// test all keys are NOT NULL for AGG table
sql "drop table if exists random_distributed_tbl_test_2;"
sql """ CREATE TABLE random_distributed_tbl_test_2 (
- `k1` LARGEINT NOT NULL
+ `k1` LARGEINT NOT NULL,
+ `k2` DECIMAL(18, 2) SUM NOT NULL
) ENGINE=OLAP
AGGREGATE KEY(`k1`)
COMMENT 'OLAP'
@@ -133,17 +134,19 @@ suite("select_random_distributed_tbl") {
);
"""
- sql """ insert into random_distributed_tbl_test_2 values(1); """
- sql """ insert into random_distributed_tbl_test_2 values(1); """
- sql """ insert into random_distributed_tbl_test_2 values(1); """
+ sql """ insert into random_distributed_tbl_test_2 values(1,
999999999999999.99); """
+ sql """ insert into random_distributed_tbl_test_2 values(1,
999999999999999.99); """
+ sql """ insert into random_distributed_tbl_test_2 values(3,
999999999999999.99); """
sql "set enable_nereids_planner = false;"
- qt_sql_17 "select k1 from random_distributed_tbl_test_2;"
- qt_sql_18 "select distinct k1 from random_distributed_tbl_test_2;"
+ qt_sql_17 "select k1 from random_distributed_tbl_test_2 order by k1;"
+ qt_sql_18 "select distinct k1 from random_distributed_tbl_test_2 order by
k1;"
+ qt_sql_19 "select k2 from random_distributed_tbl_test_2 order by k2;"
sql "set enable_nereids_planner = true;"
- qt_sql_19 "select k1 from random_distributed_tbl_test_2;"
- qt_sql_20 "select distinct k1 from random_distributed_tbl_test_2;"
+ qt_sql_20 "select k1 from random_distributed_tbl_test_2 order by k1;"
+ qt_sql_21 "select distinct k1 from random_distributed_tbl_test_2 order by
k1;"
+ qt_sql_22 "select k2 from random_distributed_tbl_test_2 order by k2;"
sql "drop table random_distributed_tbl_test_2;"
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]