This is an automated email from the ASF dual-hosted git repository.
morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new dc8c5a5ad5c [feature](nereids) add rewrite rule PushCountIntoUnionAll
(#33530)
dc8c5a5ad5c is described below
commit dc8c5a5ad5c5cf0e7d98b0c7f4a65f529fa27807
Author: feiniaofeiafei <[email protected]>
AuthorDate: Tue Oct 15 11:52:50 2024 +0800
[feature](nereids) add rewrite rule PushCountIntoUnionAll (#33530)
---
.../doris/nereids/jobs/executor/Rewriter.java | 4 +-
.../org/apache/doris/nereids/rules/RuleType.java | 1 +
.../nereids/rules/analysis/CheckAfterRewrite.java | 1 -
.../rules/rewrite/PushCountIntoUnionAll.java | 223 +++++++++++++++++
.../rules/rewrite/PushCountIntoUnionAllTest.java | 170 +++++++++++++
.../push_count_into_union_all.out | 265 +++++++++++++++++++++
.../push_count_into_union_all.groovy | 134 +++++++++++
7 files changed, 796 insertions(+), 2 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
index 51c5045aa1f..9b33f94c4c2 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
@@ -110,6 +110,7 @@ import
org.apache.doris.nereids.rules.rewrite.PullUpProjectUnderTopN;
import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoEsScan;
import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoJdbcScan;
import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoOdbcScan;
+import org.apache.doris.nereids.rules.rewrite.PushCountIntoUnionAll;
import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoin;
import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoinOnPkFk;
import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoinOneSide;
@@ -347,7 +348,8 @@ public class Rewriter extends AbstractBatchJobExecutor {
new PushDownAggThroughJoinOneSide(),
new PushDownAggThroughJoin()
)),
-
costBased(custom(RuleType.PUSH_DOWN_DISTINCT_THROUGH_JOIN,
PushDownDistinctThroughJoin::new))
+
costBased(custom(RuleType.PUSH_DOWN_DISTINCT_THROUGH_JOIN,
PushDownDistinctThroughJoin::new)),
+ topDown(new PushCountIntoUnionAll())
),
// this rule should invoke after infer predicate and push down
distinct, and before push down limit
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index fbff1590de2..e0875c63e65 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -284,6 +284,7 @@ public enum RuleType {
PUSH_CONJUNCTS_INTO_ES_SCAN(RuleTypeClass.REWRITE),
OLAP_SCAN_TABLET_PRUNE(RuleTypeClass.REWRITE),
PUSH_AGGREGATE_TO_OLAP_SCAN(RuleTypeClass.REWRITE),
+ PUSH_COUNT_INTO_UNION_ALL(RuleTypeClass.REWRITE),
EXTRACT_SINGLE_TABLE_EXPRESSION_FROM_DISJUNCTION(RuleTypeClass.REWRITE),
HIDE_ONE_ROW_RELATION_UNDER_UNION(RuleTypeClass.REWRITE),
PUSH_PROJECT_THROUGH_UNION(RuleTypeClass.REWRITE),
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java
index 562e84275df..47cffe28c55 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java
@@ -107,7 +107,6 @@ public class CheckAfterRewrite extends
OneAnalysisRuleFactory {
if (notFromChildren.isEmpty()) {
return;
}
-
notFromChildren = removeValidSlotsNotFromChildren(notFromChildren,
childrenOutput);
if (!notFromChildren.isEmpty()) {
if (plan.arity() != 0 && plan.child(0) instanceof
LogicalAggregate) {
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushCountIntoUnionAll.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushCountIntoUnionAll.java
new file mode 100644
index 00000000000..ddca8e479a3
--- /dev/null
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushCountIntoUnionAll.java
@@ -0,0 +1,223 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum0;
+import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation;
+import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableList.Builder;
+import com.google.common.collect.Lists;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * LogicalAggregate (groupByExpr=[c1#13], outputExpr=[c1#13, count(c1#13) AS
`count(c1)`#15])
+ * +--LogicalUnion (outputs=[c1#13], regularChildrenOutputs=[[c1#9], [a#4],
[a#7]])
+ * |--child1 (output = [[c1#9]])
+ * |--child2 (output = [[a#4]])
+ * +--child3 (output = [[a#7]])
+ * transform to:
+ * LogicalAggregate (groupByExpr=[c1#13], outputExpr=[c1#13,
sum0(count(c1)#19) AS `count(c1)`#15])
+ * +--LogicalUnion (outputs=[c1#13, count(c1)#19],
regularChildrenOutputs=[[c1#9, count(c1)#16],
+ * [a#4, count(a)#17], [a#7, count(a)#18]])
+ * |--LogicalAggregate (groupByExpr=[c1#9], outputExpr=[c1#9, count(c1#9)
AS `count(c1)`#16])
+ * | +--child1
+ * |--LogicalAggregate (groupByExpr=[a#4], outputExpr=[a#4, count(a#4) AS
`count(a)`#17])
+ * | +--child2
+ * +--LogicalAggregate (groupByExpr=[a#7], outputExpr=[a#7, count(a#7) AS
`count(a)`#18])
+ * +--child3
+ */
+public class PushCountIntoUnionAll implements RewriteRuleFactory {
+ @Override
+ public List<Rule> buildRules() {
+ return
ImmutableList.of(logicalAggregate(logicalUnion().when(this::checkUnion))
+ .when(this::checkAgg)
+ .then(this::doPush)
+ .toRule(RuleType.PUSH_COUNT_INTO_UNION_ALL),
+
logicalAggregate(logicalProject(logicalUnion().when(this::checkUnion)))
+ .when(this::checkAgg)
+ .when(this::checkProjectUseless)
+ .then(this::removeProjectAndPush)
+ .toRule(RuleType.PUSH_COUNT_INTO_UNION_ALL)
+ );
+ }
+
+ private Plan doPush(LogicalAggregate<LogicalUnion> agg) {
+ LogicalUnion logicalUnion = agg.child();
+ List<Slot> outputs = logicalUnion.getOutput();
+ Map<Slot, Integer> replaceMap = new HashMap<>();
+ for (int i = 0; i < outputs.size(); i++) {
+ replaceMap.put(outputs.get(i), i);
+ }
+ int childSize = logicalUnion.children().size();
+ List<Expression> upperGroupByExpressions = agg.getGroupByExpressions();
+ List<NamedExpression> upperOutputExpressions =
agg.getOutputExpressions();
+ Builder<Plan> newChildren =
ImmutableList.builderWithExpectedSize(childSize);
+ Builder<List<SlotReference>> childrenOutputs =
ImmutableList.builderWithExpectedSize(childSize);
+ // create the pushed down LogicalAggregate
+ List<List<SlotReference>> childSlots =
logicalUnion.getRegularChildrenOutputs();
+ for (int i = 0; i < childSize; i++) {
+ List<SlotReference> childOutputs = childSlots.get(i);
+ List<Expression> groupByExpressions =
replaceExpressionByUnionAll(upperGroupByExpressions, replaceMap,
+ childOutputs);
+ List<NamedExpression> outputExpressions =
replaceExpressionByUnionAll(upperOutputExpressions, replaceMap,
+ childOutputs);
+ Plan child = logicalUnion.children().get(i);
+ LogicalAggregate<Plan> logicalAggregate = new
LogicalAggregate<>(groupByExpressions, outputExpressions,
+ child);
+ newChildren.add(logicalAggregate);
+ childrenOutputs.add((List<SlotReference>) (List)
logicalAggregate.getOutput());
+ }
+
+ // create the new LogicalUnion
+ LogicalSetOperation newLogicalUnion =
logicalUnion.withChildrenAndTheirOutputs(newChildren.build(),
+ childrenOutputs.build());
+ List<NamedExpression> newLogicalUnionOutputs = Lists.newArrayList();
+ for (NamedExpression ce : upperOutputExpressions) {
+ if (ce instanceof Alias) {
+ newLogicalUnionOutputs.add(new SlotReference(ce.getName(),
ce.getDataType(), ce.nullable()));
+ } else if (ce instanceof SlotReference) {
+ newLogicalUnionOutputs.add(ce);
+ } else {
+ return logicalUnion;
+ }
+ }
+ newLogicalUnion =
newLogicalUnion.withNewOutputs(newLogicalUnionOutputs);
+
+ // The count in the upper agg is converted to sum0, and the alias id
and name of the count remain unchanged.
+ Builder<NamedExpression> newUpperOutputExpressions =
ImmutableList.builderWithExpectedSize(
+ upperOutputExpressions.size());
+ for (int i = 0; i < upperOutputExpressions.size(); i++) {
+ NamedExpression sum0Child = newLogicalUnionOutputs.get(i);
+ Expression rewrittenExpression =
upperOutputExpressions.get(i).rewriteDownShortCircuit(expr -> {
+ if (expr instanceof Alias && ((Alias) expr).child() instanceof
Count) {
+ Alias alias = ((Alias) expr);
+ return new Alias(alias.getExprId(), new Sum0(sum0Child),
alias.getName());
+ }
+ return expr;
+ });
+ newUpperOutputExpressions.add((NamedExpression)
rewrittenExpression);
+ }
+ return agg.withAggOutputChild(newUpperOutputExpressions.build(),
newLogicalUnion);
+ }
+
+ private <E extends Expression> List<E> replaceExpressionByUnionAll(List<E>
expressions,
+ Map<Slot, Integer> replaceMap, List<? extends Slot> childOutputs) {
+ // Traverse expressions. If a slot in replaceMap appears, replace it
with childOutputs[replaceMap[slot]]
+ return ExpressionUtils.rewriteDownShortCircuit(expressions, expr -> {
+ if (expr instanceof Alias && ((Alias) expr).child() instanceof
Count) {
+ Count cnt = (Count) ((Alias) expr).child();
+ if (cnt.isCountStar()) {
+ return new Alias(new Count());
+ } else {
+ Expression newCntChild =
cnt.child(0).rewriteDownShortCircuit(e -> {
+ if (e instanceof SlotReference &&
replaceMap.containsKey(e)) {
+ return childOutputs.get(replaceMap.get(e));
+ }
+ return e;
+ });
+ return new Alias(new Count(newCntChild));
+ }
+ } else if (expr instanceof SlotReference &&
replaceMap.containsKey(expr)) {
+ return childOutputs.get(replaceMap.get(expr));
+ }
+ return expr;
+ });
+ }
+
+ private boolean checkAgg(LogicalAggregate aggregate) {
+ Set<Count> res =
ExpressionUtils.collect(aggregate.getOutputExpressions(), expr -> expr
instanceof Count);
+ if (res.isEmpty()) {
+ return false;
+ }
+ return !hasUnsuportedAggFunc(aggregate);
+ }
+
+ private boolean
checkProjectUseless(LogicalAggregate<LogicalProject<LogicalUnion>> agg) {
+ LogicalProject<LogicalUnion> project = agg.child();
+ if (project.getProjects().size() != 1) {
+ return false;
+ }
+ if (!(project.getProjects().get(0) instanceof Alias)) {
+ return false;
+ }
+ Alias alias = (Alias) project.getProjects().get(0);
+ if (!alias.child(0).equals(new TinyIntLiteral((byte) 1))) {
+ return false;
+ }
+ List<NamedExpression> aggOutputs = agg.getOutputExpressions();
+ Slot slot = project.getOutput().get(0);
+ if (ExpressionUtils.anyMatch(aggOutputs, expr -> expr.equals(slot))) {
+ return false;
+ }
+ return true;
+ }
+
+ private Plan
removeProjectAndPush(LogicalAggregate<LogicalProject<LogicalUnion>> agg) {
+ Plan afterRemove = agg.withChildren(agg.child().child());
+ return doPush((LogicalAggregate<LogicalUnion>) afterRemove);
+ }
+
+ private boolean hasUnsuportedAggFunc(LogicalAggregate aggregate) {
+ // only support count, not support sum,min... and not support
count(distinct)
+ return ExpressionUtils.deapAnyMatch(aggregate.getOutputExpressions(),
expr -> {
+ if (expr instanceof AggregateFunction) {
+ if (!(expr instanceof Count)) {
+ return true;
+ } else {
+ return ((Count) expr).isDistinct();
+ }
+ } else {
+ return false;
+ }
+ });
+ }
+
+ private boolean checkUnion(LogicalUnion union) {
+ if (union.getQualifier() != Qualifier.ALL) {
+ return false;
+ }
+ if (union.children() == null || union.children().isEmpty()) {
+ return false;
+ }
+ if (!union.getConstantExprsList().isEmpty()) {
+ return false;
+ }
+ return true;
+ }
+}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushCountIntoUnionAllTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushCountIntoUnionAllTest.java
new file mode 100644
index 00000000000..9de03dd6a57
--- /dev/null
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushCountIntoUnionAllTest.java
@@ -0,0 +1,170 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum0;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.utframe.TestWithFeService;
+
+import org.junit.jupiter.api.Test;
+
+public class PushCountIntoUnionAllTest extends TestWithFeService implements
MemoPatternMatchSupported {
+ @Override
+ protected void runBeforeAll() throws Exception {
+ createDatabase("test");
+ createTable("create table test.t1 (\n"
+ + "id int not null,\n"
+ + "a varchar(128),\n"
+ + "b int,c int)\n"
+ + "distributed by hash(id) buckets 10\n"
+ + "properties('replication_num' = '1');");
+ connectContext.setDatabase("test");
+
connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION");
+ }
+
+ @Test
+ void testPushCountStar() {
+ String sql = "select id,count(1) from (select id,a from t1 union all
select id,a from t1 where id>10) t group by id;";
+ PlanChecker.from(connectContext)
+ .analyze(sql)
+ .rewrite()
+ .matches(
+ logicalAggregate(
+ logicalUnion(logicalAggregate().when(agg ->
ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class)),
+ logicalAggregate().when(agg ->
ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class)))
+ ).when(agg ->
ExpressionUtils.containsType(agg.getOutputExpressions(), Sum0.class))
+ );
+ String sql2 = "select id,count(*) from (select id,a from t1 union all
select id,a from t1 where id>10) t group by id;";
+ PlanChecker.from(connectContext)
+ .analyze(sql2)
+ .rewrite()
+ .matches(
+ logicalAggregate(
+ logicalUnion(logicalAggregate().when(agg ->
ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class)),
+ logicalAggregate().when(agg ->
ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class)))
+ ).when(agg ->
ExpressionUtils.containsType(agg.getOutputExpressions(), Sum0.class))
+ );
+ }
+
+ @Test
+ void testPushCountStarNoOtherColumn() {
+ String sql = "select count(1) from (select id,a from t1 union all
select id,a from t1 where id>10) t;";
+ PlanChecker.from(connectContext)
+ .analyze(sql)
+ .rewrite()
+ .matches(
+ logicalAggregate(
+ logicalUnion(logicalAggregate(),
logicalAggregate())
+ ).when(agg ->
ExpressionUtils.containsType(agg.getOutputExpressions(), Sum0.class))
+ );
+ String sql2 = "select count(*) from (select id,a from t1 union all
select id,a from t1 where id>10) t;";
+ PlanChecker.from(connectContext)
+ .analyze(sql2)
+ .rewrite()
+ .matches(
+ logicalAggregate(
+ logicalUnion(logicalAggregate(),
logicalAggregate())
+ ).when(agg ->
ExpressionUtils.containsType(agg.getOutputExpressions(), Sum0.class))
+ );
+ }
+
+ @Test
+ void testPushCountColumn() {
+ String sql = "select count(id) from (select id,a from t1 union all
select id,a from t1 where id>10) t;";
+ PlanChecker.from(connectContext)
+ .analyze(sql)
+ .rewrite()
+ .matches(
+ logicalAggregate(
+ logicalUnion(logicalAggregate().when(agg ->
ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class)),
+ logicalAggregate().when(agg ->
ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class)))
+ ).when(agg ->
ExpressionUtils.containsType(agg.getOutputExpressions(), Sum0.class))
+ );
+ }
+
+ @Test
+ void testPushCountColumnWithGroupBy() {
+ String sql = "select count(id),a from (select id,a from t1 union all
select id,a from t1 where id>10) t group by a;";
+ PlanChecker.from(connectContext)
+ .analyze(sql)
+ .rewrite()
+ .matches(
+ logicalAggregate(
+ logicalUnion(logicalAggregate().when(agg ->
ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class)
+ && agg.getGroupByExpressions().size()
== 1),
+ logicalAggregate().when(agg ->
ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class)
+ &&
agg.getGroupByExpressions().size() == 1))
+ ).when(agg ->
ExpressionUtils.containsType(agg.getOutputExpressions(), Sum0.class))
+ );
+ }
+
+ @Test
+ void testPush2CountColumn() {
+ String sql = "select count(id), count(b), a from (select id,b,a from
t1 union all select id,a,b from t1 where id>10) t group by a;";
+ PlanChecker.from(connectContext)
+ .analyze(sql)
+ .rewrite()
+ .matches(
+ logicalAggregate(
+ logicalUnion(logicalAggregate().when(agg ->
ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class)
+ &&
agg.getGroupByExpressions().size() == 1),
+ logicalAggregate().when(agg ->
ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class)
+ &&
agg.getGroupByExpressions().size() == 1))
+ ).when(agg ->
ExpressionUtils.containsType(agg.getOutputExpressions(), Sum0.class))
+ );
+ }
+
+ @Test
+ void testNotPushCountBecauseOtherAggFunc() {
+ String sql = "select count(1), sum(id) from (select id,a from t1 union
all select id,a from t1 where id>10) t;";
+ PlanChecker.from(connectContext)
+ .analyze(sql)
+ .rewrite()
+ .nonMatch(
+ logicalAggregate(
+ logicalUnion(logicalAggregate(),
logicalAggregate())
+ ).when(agg ->
ExpressionUtils.containsType(agg.getOutputExpressions(), Sum0.class))
+ );
+ }
+
+ @Test
+ void testNotPushCountBecauseUnion() {
+ String sql = "select count(1), sum(id) from (select id,a from t1 union
select id,a from t1 where id>10) t;";
+ PlanChecker.from(connectContext)
+ .analyze(sql)
+ .rewrite()
+ .nonMatch(
+ logicalAggregate(
+ logicalUnion(logicalAggregate(),
logicalAggregate())
+ ).when(agg ->
ExpressionUtils.containsType(agg.getOutputExpressions(), Sum0.class))
+ );
+
+ String sql2 = "select count(1), sum(id) from (select id,a from t1
union all select id,a from t1 where id>10 union all select 1,3) t;";
+ PlanChecker.from(connectContext)
+ .analyze(sql2)
+ .rewrite()
+ .nonMatch(
+ logicalAggregate(
+ logicalUnion(logicalAggregate(),
logicalAggregate())
+ ).when(agg ->
ExpressionUtils.containsType(agg.getOutputExpressions(), Sum0.class))
+ );
+ }
+}
diff --git
a/regression-test/data/nereids_rules_p0/push_count_into_union_all/push_count_into_union_all.out
b/regression-test/data/nereids_rules_p0/push_count_into_union_all/push_count_into_union_all.out
new file mode 100644
index 00000000000..cfdef2ebfec
--- /dev/null
+++
b/regression-test/data/nereids_rules_p0/push_count_into_union_all/push_count_into_union_all.out
@@ -0,0 +1,265 @@
+-- This file is automatically generated. You should know what you did if you
want to edit this
+-- !count_group_by --
+1 10
+2 2
+3 6
+4 2
+5 10
+7 2
+
+-- !count_group_by_shape --
+PhysicalResultSink
+--hashAgg[GLOBAL]
+----hashAgg[LOCAL]
+------PhysicalUnion
+--------hashAgg[GLOBAL]
+----------hashAgg[LOCAL]
+------------filter((mal_test_push_count.a > 1))
+--------------PhysicalOlapScan[mal_test_push_count]
+--------hashAgg[GLOBAL]
+----------hashAgg[LOCAL]
+------------filter((mal_test_push_count.a < 100))
+--------------PhysicalOlapScan[mal_test_push_count]
+--------hashAgg[GLOBAL]
+----------hashAgg[LOCAL]
+------------filter((mal_test_push_count.a = 1))
+--------------PhysicalOlapScan[mal_test_push_count]
+
+-- !count_group_by_none --
+32
+
+-- !count_group_by_none_shape --
+PhysicalResultSink
+--hashAgg[GLOBAL]
+----hashAgg[LOCAL]
+------PhysicalUnion
+--------hashAgg[GLOBAL]
+----------hashAgg[LOCAL]
+------------filter((mal_test_push_count.a > 1))
+--------------PhysicalOlapScan[mal_test_push_count]
+--------hashAgg[GLOBAL]
+----------hashAgg[LOCAL]
+------------filter((mal_test_push_count.a < 100))
+--------------PhysicalOlapScan[mal_test_push_count]
+--------hashAgg[GLOBAL]
+----------hashAgg[LOCAL]
+------------filter((mal_test_push_count.a = 1))
+--------------PhysicalOlapScan[mal_test_push_count]
+
+-- !count_expr_group_by_none --
+32
+
+-- !count_expr_group_by_none_shape --
+PhysicalResultSink
+--hashAgg[GLOBAL]
+----hashAgg[LOCAL]
+------PhysicalUnion
+--------hashAgg[GLOBAL]
+----------hashAgg[LOCAL]
+------------filter((mal_test_push_count.a > 1))
+--------------PhysicalOlapScan[mal_test_push_count]
+--------hashAgg[GLOBAL]
+----------hashAgg[LOCAL]
+------------filter((mal_test_push_count.a < 100))
+--------------PhysicalOlapScan[mal_test_push_count]
+--------hashAgg[GLOBAL]
+----------hashAgg[LOCAL]
+------------filter((mal_test_push_count.a = 1))
+--------------PhysicalOlapScan[mal_test_push_count]
+
+-- !union_const --
+28
+
+-- !union_const_shape --
+PhysicalResultSink
+--hashAgg[GLOBAL]
+----hashAgg[LOCAL]
+------PhysicalUnion
+--------filter((mal_test_push_count.a > 1))
+----------PhysicalOlapScan[mal_test_push_count]
+--------filter((mal_test_push_count.a < 100))
+----------PhysicalOlapScan[mal_test_push_count]
+
+-- !count_expr_group_by_ditinct_none_shape --
+PhysicalResultSink
+--hashAgg[DISTINCT_GLOBAL]
+----hashAgg[DISTINCT_LOCAL]
+------hashAgg[GLOBAL]
+--------hashAgg[LOCAL]
+----------PhysicalUnion
+------------filter((mal_test_push_count.a > 1))
+--------------PhysicalOlapScan[mal_test_push_count]
+------------filter((mal_test_push_count.a < 100))
+--------------PhysicalOlapScan[mal_test_push_count]
+------------filter((mal_test_push_count.a = 1))
+--------------PhysicalOlapScan[mal_test_push_count]
+
+-- !union_all_child_alias --
+1 10
+2 2
+3 6
+4 2
+5 10
+7 2
+
+-- !union_all_child_alias_shape --
+PhysicalResultSink
+--hashAgg[GLOBAL]
+----hashAgg[LOCAL]
+------PhysicalUnion
+--------hashAgg[GLOBAL]
+----------hashAgg[LOCAL]
+------------filter((mal_test_push_count.a > 1))
+--------------PhysicalOlapScan[mal_test_push_count]
+--------hashAgg[GLOBAL]
+----------hashAgg[LOCAL]
+------------filter((mal_test_push_count.a < 100))
+--------------PhysicalOlapScan[mal_test_push_count]
+--------hashAgg[GLOBAL]
+----------hashAgg[LOCAL]
+------------filter((mal_test_push_count.a = 1))
+--------------PhysicalOlapScan[mal_test_push_count]
+
+-- !union_all_child_expr --
+1 5
+3 1
+4 3
+5 1
+6 5
+8 1
+101 5
+102 1
+103 3
+104 1
+105 5
+107 1
+
+-- !count_group_by_count_other --
+1 10
+2 2
+3 6
+4 2
+5 8
+7 2
+
+-- !count_group_by_count_other_shape --
+PhysicalResultSink
+--PhysicalQuickSort[MERGE_SORT]
+----PhysicalQuickSort[LOCAL_SORT]
+------hashAgg[GLOBAL]
+--------hashAgg[LOCAL]
+----------PhysicalUnion
+------------hashAgg[GLOBAL]
+--------------hashAgg[LOCAL]
+----------------filter((mal_test_push_count.a > 1))
+------------------PhysicalOlapScan[mal_test_push_count]
+------------hashAgg[GLOBAL]
+--------------hashAgg[LOCAL]
+----------------filter((mal_test_push_count.a < 100))
+------------------PhysicalOlapScan[mal_test_push_count]
+------------hashAgg[GLOBAL]
+--------------hashAgg[LOCAL]
+----------------filter((mal_test_push_count.a = 1))
+------------------PhysicalOlapScan[mal_test_push_count]
+
+-- !count_group_by_multi_col --
+1 4
+1 6
+2 2
+3 2
+3 4
+4 2
+5 2
+5 6
+7 2
+
+-- !count_group_by_multi_col_shape --
+PhysicalResultSink
+--PhysicalQuickSort[MERGE_SORT]
+----PhysicalQuickSort[LOCAL_SORT]
+------hashAgg[GLOBAL]
+--------hashAgg[LOCAL]
+----------PhysicalUnion
+------------hashAgg[LOCAL]
+--------------filter((mal_test_push_count.a > 1))
+----------------PhysicalOlapScan[mal_test_push_count]
+------------hashAgg[LOCAL]
+--------------filter((mal_test_push_count.a < 100))
+----------------PhysicalOlapScan[mal_test_push_count]
+------------hashAgg[LOCAL]
+--------------filter((mal_test_push_count.a = 1))
+----------------PhysicalOlapScan[mal_test_push_count]
+
+-- !test_upper_refer --
+1 10
+2 2
+3 6
+4 2
+5 8
+7 2
+
+-- !test_upper_refer_shape --
+PhysicalResultSink
+--PhysicalQuickSort[MERGE_SORT]
+----PhysicalQuickSort[LOCAL_SORT]
+------hashAgg[GLOBAL]
+--------hashAgg[LOCAL]
+----------PhysicalUnion
+------------hashAgg[GLOBAL]
+--------------hashAgg[LOCAL]
+----------------filter((mal_test_push_count.a > 1))
+------------------PhysicalOlapScan[mal_test_push_count]
+------------hashAgg[GLOBAL]
+--------------hashAgg[LOCAL]
+----------------filter((mal_test_push_count.a < 100))
+------------------PhysicalOlapScan[mal_test_push_count]
+------------hashAgg[GLOBAL]
+--------------hashAgg[LOCAL]
+----------------filter((mal_test_push_count.a = 1))
+------------------PhysicalOlapScan[mal_test_push_count]
+
+-- !unsupport_agg_func --
+PhysicalResultSink
+--PhysicalQuickSort[MERGE_SORT]
+----PhysicalQuickSort[LOCAL_SORT]
+------hashAgg[GLOBAL]
+--------hashAgg[LOCAL]
+----------PhysicalUnion
+------------filter((mal_test_push_count.a > 1))
+--------------PhysicalOlapScan[mal_test_push_count]
+------------filter((mal_test_push_count.a < 100))
+--------------PhysicalOlapScan[mal_test_push_count]
+------------filter((mal_test_push_count.a = 1))
+--------------PhysicalOlapScan[mal_test_push_count]
+
+-- !test_upper_refer_count_star --
+1 10
+2 2
+3 6
+4 2
+5 10
+7 2
+
+-- !test_upper_refer_count_star_shape --
+PhysicalResultSink
+--PhysicalQuickSort[MERGE_SORT]
+----PhysicalQuickSort[LOCAL_SORT]
+------hashAgg[GLOBAL]
+--------hashAgg[LOCAL]
+----------PhysicalUnion
+------------hashAgg[GLOBAL]
+--------------hashAgg[LOCAL]
+----------------filter((mal_test_push_count.a > 1))
+------------------PhysicalOlapScan[mal_test_push_count]
+------------hashAgg[GLOBAL]
+--------------hashAgg[LOCAL]
+----------------filter((mal_test_push_count.a < 100))
+------------------PhysicalOlapScan[mal_test_push_count]
+------------hashAgg[GLOBAL]
+--------------hashAgg[LOCAL]
+----------------filter((mal_test_push_count.a = 1))
+------------------PhysicalOlapScan[mal_test_push_count]
+
+-- !test_count_star --
+32
+
diff --git
a/regression-test/suites/nereids_rules_p0/push_count_into_union_all/push_count_into_union_all.groovy
b/regression-test/suites/nereids_rules_p0/push_count_into_union_all/push_count_into_union_all.groovy
new file mode 100644
index 00000000000..71e07318820
--- /dev/null
+++
b/regression-test/suites/nereids_rules_p0/push_count_into_union_all/push_count_into_union_all.groovy
@@ -0,0 +1,134 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("push_count_into_union_all") { // count() over UNION ALL cases; each qt_<tag> maps to the -- !<tag> -- section of the expected .out file
+ sql "SET enable_nereids_planner=true" // run on the Nereids planner only (fallback is disabled on the next line)
+ sql "SET enable_fallback_to_original_planner=false"
+ sql """SET ignore_shape_nodes='PhysicalDistribute,PhysicalProject'"""
+ sql """
+ DROP TABLE IF EXISTS mal_test_push_count
+ """
+ // Recreate the test table: three int columns, hash-distributed by pk into 10 buckets, single replica.
+ sql """
+ create table mal_test_push_count(pk int, a int, b int) distributed by
hash(pk) buckets 10
+ properties('replication_num' = '1');
+ """
+ // Load 17 rows; column a holds one NULL and column b holds one NULL.
+ sql """
+ insert into mal_test_push_count
values(2,1,3),(1,1,2),(3,5,6),(6,null,6),(4,5,6),(2,1,4),(2,3,5),(1,1,4)
+
,(3,5,6),(3,5,null),(6,7,1),(2,1,7),(2,4,2),(2,3,9),(1,3,6),(3,5,8),(3,2,8);
+ """
+ sql "sync"
+ // count(a) grouped by a over three UNION ALL branches (a>1, a<100, a=1): result rows.
+ qt_count_group_by """
+ select a,count(a) c1 from (select a,b from mal_test_push_count where
a>1 union all select a,b from mal_test_push_count where a<100
+ union all select a,b from mal_test_push_count where a=1 ) t group by a
order by 1,2;"""
+ // Same query without ORDER BY: plan shape.
+ qt_count_group_by_shape """explain shape plan
+ select a,count(a) c1 from (select a,b from mal_test_push_count where
a>1 union all select a,b from mal_test_push_count where a<100
+ union all select a,b from mal_test_push_count where a=1 ) t group by
a;"""
+ // count(a) with no GROUP BY: result.
+ qt_count_group_by_none """
+ select count(a) c1 from (select a,b from mal_test_push_count where a>1
union all select a,b from mal_test_push_count where a<100
+ union all select a,b from mal_test_push_count where a=1 ) t order by
1"""
+ // count(a) with no GROUP BY: plan shape.
+ qt_count_group_by_none_shape """explain shape plan
+ select count(a) c1 from (select a,b from mal_test_push_count where a>1
union all select a,b from mal_test_push_count where a<100
+ union all select a,b from mal_test_push_count where a=1 ) t"""
+ // count over an expression (a+1), no GROUP BY: result.
+ qt_count_expr_group_by_none """
+ select count(a+1) c1 from (select a,b from mal_test_push_count where
a>1 union all select a,b from mal_test_push_count where a<100
+ union all select a,b from mal_test_push_count where a=1 ) t order by
1"""
+ // count(a+1), no GROUP BY: plan shape.
+ qt_count_expr_group_by_none_shape """explain shape plan
+ select count(a+1) c1 from (select a,b from mal_test_push_count where
a>1 union all select a,b from mal_test_push_count where a<100
+ union all select a,b from mal_test_push_count where a=1 ) t"""
+ // One UNION ALL branch is a constant row (select 1,2): result.
+ qt_union_const """
+ select count(a+1) c1 from (select a,b from mal_test_push_count where
a>1 union all select a,b from mal_test_push_count where a<100
+ union all select 1,2 ) t order by 1"""
+ // Constant-row branch: plan shape.
+ qt_union_const_shape """explain shape plan
+ select count(a+1) c1 from (select a,b from mal_test_push_count where
a>1 union all select a,b from mal_test_push_count where a<100
+ union all select 1,2 ) t"""
+ // count(distinct a+1): plan shape only. NOTE(review): the tag spells ditinct; renaming it would also require renaming the matching .out section.
+ qt_count_expr_group_by_ditinct_none_shape """explain shape plan
+ select count(distinct a+1) c1 from (select a,b from
mal_test_push_count where a>1 union all select a,b from mal_test_push_count
where a<100
+ union all select a,b from mal_test_push_count where a=1 ) t"""
+ // First branch aliases its columns (a as c1, b as c2); grouping and count use c1: result.
+ qt_union_all_child_alias """
+ select c1,count(c1) from (select a as c1,b as c2 from
mal_test_push_count where a>1 union all select a,b from mal_test_push_count
where a<100
+ union all select a,b from mal_test_push_count where a=1 ) t group by
c1 order by 1,2;"""
+ // Aliased first branch: plan shape.
+ qt_union_all_child_alias_shape """
+ explain shape plan
+ select c1,count(c1) from (select a as c1,b as c2 from
mal_test_push_count where a>1 union all select a,b from mal_test_push_count
where a<100
+ union all select a,b from mal_test_push_count where a=1 ) t group by
c1;"""
+ // Branches project different expressions (a+1, a+100, abs(a)): result only, no shape case follows.
+ qt_union_all_child_expr """
+ select c1,count(c1) from (select a+1 as c1,b as c2 from
mal_test_push_count where a>1 union all select a+100,b from mal_test_push_count
where a<100
+ union all select abs(a),b from mal_test_push_count where a=1 ) t group
by c1 order by 1,2;"""
+ // Count column (b) differs from the grouping column (a): result.
+ qt_count_group_by_count_other """
+ select a,count(b) c1 from (select a,b from mal_test_push_count where
a>1 union all select a,b from mal_test_push_count where a<100
+ union all select a,b from mal_test_push_count where a=1 ) t group by a
order by 1,2;"""
+ // count(b) grouped by a: plan shape (this variant keeps the ORDER BY clause).
+ qt_count_group_by_count_other_shape """
+ explain shape plan
+ select a,count(b) c1 from (select a,b from mal_test_push_count where
a>1 union all select a,b from mal_test_push_count where a<100
+ union all select a,b from mal_test_push_count where a=1 ) t group by a
order by 1,2;"""
+ // Group by two columns (a, pk) while selecting a and count(b): result.
+ qt_count_group_by_multi_col """
+ select a,count(b) c1 from (select a,b,pk from mal_test_push_count
where a>1 union all select a,b,pk from mal_test_push_count where a<100
+ union all select a,b,pk from mal_test_push_count where a=1 ) t group
by a,pk order by 1,2;"""
+ // Two-column grouping: plan shape.
+ qt_count_group_by_multi_col_shape """
+ explain shape plan
+ select a,count(b) c1 from (select a,b,pk from mal_test_push_count
where a>1 union all select a,b,pk from mal_test_push_count where a<100
+ union all select a,b,pk from mal_test_push_count where a=1 ) t group
by a,pk order by 1,2;"""
+ // Aggregate wrapped in an outer select that reads a and c1: result.
+ qt_test_upper_refer """
+ select a,c1 from (
+ select a,count(b) c1 from (select a,b from mal_test_push_count where
a>1 union all select a,b from mal_test_push_count where a<100
+ union all select a,b from mal_test_push_count where a=1 ) t group by
a) outer_table order by 1,2;"""
+ // Outer select over the aggregate: plan shape.
+ qt_test_upper_refer_shape """
+ explain shape plan
+ select a,c1 from (
+ select a,count(b) c1 from (select a,b from mal_test_push_count where
a>1 union all select a,b from mal_test_push_count where a<100
+ union all select a,b from mal_test_push_count where a=1 ) t group by
a) outer_table order by 1,2;"""
+ // count(b) together with sum(b): plan shape; the recorded plan keeps both hashAgg nodes above PhysicalUnion.
+ qt_unsupport_agg_func """
+ explain shape plan
+ select a,count(b) c1,sum(b) from (select a,b from mal_test_push_count
where a>1 union all select a,b from mal_test_push_count where a<100
+ union all select a,b from mal_test_push_count where a=1 ) t group by a
order by 1,2,3;"""
+ // count(*) grouped by a, read by an outer select: result.
+ qt_test_upper_refer_count_star """
+ select a,c1 from (
+ select a,count(*) c1 from (select a,b from mal_test_push_count where
a>1 union all select a,b from mal_test_push_count where a<100
+ union all select a,b from mal_test_push_count where a=1 ) t group by
a) outer_table order by 1,2;"""
+ // Same count(*) query: plan shape; the recorded plan has a hashAgg pair under each PhysicalUnion branch.
+ qt_test_upper_refer_count_star_shape """
+ explain shape plan
+ select a,c1 from (
+ select a,count(*) c1 from (select a,b from mal_test_push_count where
a>1 union all select a,b from mal_test_push_count where a<100
+ union all select a,b from mal_test_push_count where a=1 ) t group by
a) outer_table order by 1,2;"""
+ // count(*) with no GROUP BY: result. NOTE(review): ORDER BY 1,2 names ordinal 2 but only one column is selected - confirm this is intended.
+ qt_test_count_star """
+ select count(*) from (select a,b from mal_test_push_count where a>1
union all select a,b from mal_test_push_count where a<100
+ union all select a,b from mal_test_push_count where a=1) t order by
1,2;"""
+}
\ No newline at end of file
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]