This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new dc8c5a5ad5c [feature](nereids) add rewrite rule PushCountIntoUnionAll 
(#33530)
dc8c5a5ad5c is described below

commit dc8c5a5ad5c5cf0e7d98b0c7f4a65f529fa27807
Author: feiniaofeiafei <[email protected]>
AuthorDate: Tue Oct 15 11:52:50 2024 +0800

    [feature](nereids) add rewrite rule PushCountIntoUnionAll (#33530)
---
 .../doris/nereids/jobs/executor/Rewriter.java      |   4 +-
 .../org/apache/doris/nereids/rules/RuleType.java   |   1 +
 .../nereids/rules/analysis/CheckAfterRewrite.java  |   1 -
 .../rules/rewrite/PushCountIntoUnionAll.java       | 223 +++++++++++++++++
 .../rules/rewrite/PushCountIntoUnionAllTest.java   | 170 +++++++++++++
 .../push_count_into_union_all.out                  | 265 +++++++++++++++++++++
 .../push_count_into_union_all.groovy               | 134 +++++++++++
 7 files changed, 796 insertions(+), 2 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
index 51c5045aa1f..9b33f94c4c2 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
@@ -110,6 +110,7 @@ import 
org.apache.doris.nereids.rules.rewrite.PullUpProjectUnderTopN;
 import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoEsScan;
 import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoJdbcScan;
 import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoOdbcScan;
+import org.apache.doris.nereids.rules.rewrite.PushCountIntoUnionAll;
 import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoin;
 import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoinOnPkFk;
 import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoinOneSide;
@@ -347,7 +348,8 @@ public class Rewriter extends AbstractBatchJobExecutor {
                                 new PushDownAggThroughJoinOneSide(),
                                 new PushDownAggThroughJoin()
                         )),
-                        
costBased(custom(RuleType.PUSH_DOWN_DISTINCT_THROUGH_JOIN, 
PushDownDistinctThroughJoin::new))
+                        
costBased(custom(RuleType.PUSH_DOWN_DISTINCT_THROUGH_JOIN, 
PushDownDistinctThroughJoin::new)),
+                        topDown(new PushCountIntoUnionAll())
                 ),
 
                 // this rule should invoke after infer predicate and push down 
distinct, and before push down limit
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index fbff1590de2..e0875c63e65 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -284,6 +284,7 @@ public enum RuleType {
     PUSH_CONJUNCTS_INTO_ES_SCAN(RuleTypeClass.REWRITE),
     OLAP_SCAN_TABLET_PRUNE(RuleTypeClass.REWRITE),
     PUSH_AGGREGATE_TO_OLAP_SCAN(RuleTypeClass.REWRITE),
+    PUSH_COUNT_INTO_UNION_ALL(RuleTypeClass.REWRITE),
     EXTRACT_SINGLE_TABLE_EXPRESSION_FROM_DISJUNCTION(RuleTypeClass.REWRITE),
     HIDE_ONE_ROW_RELATION_UNDER_UNION(RuleTypeClass.REWRITE),
     PUSH_PROJECT_THROUGH_UNION(RuleTypeClass.REWRITE),
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java
index 562e84275df..47cffe28c55 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java
@@ -107,7 +107,6 @@ public class CheckAfterRewrite extends 
OneAnalysisRuleFactory {
         if (notFromChildren.isEmpty()) {
             return;
         }
-
         notFromChildren = removeValidSlotsNotFromChildren(notFromChildren, 
childrenOutput);
         if (!notFromChildren.isEmpty()) {
             if (plan.arity() != 0 && plan.child(0) instanceof 
LogicalAggregate) {
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushCountIntoUnionAll.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushCountIntoUnionAll.java
new file mode 100644
index 00000000000..ddca8e479a3
--- /dev/null
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushCountIntoUnionAll.java
@@ -0,0 +1,223 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum0;
+import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation;
+import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableList.Builder;
+import com.google.common.collect.Lists;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Push a count-only aggregate below a UNION ALL, so that every union child is pre-aggregated and the
+ * upper aggregate only has to add up the partial counts with sum0.
+ *
+ * LogicalAggregate (groupByExpr=[c1#13], outputExpr=[c1#13, count(c1#13) AS `count(c1)`#15])
+ *  +--LogicalUnion (outputs=[c1#13], regularChildrenOutputs=[[c1#9], [a#4], [a#7]])
+ *    |--child1 (output = [[c1#9]])
+ *    |--child2 (output = [[a#4]])
+ *    +--child3 (output = [[a#7]])
+ * transform to:
+ * LogicalAggregate (groupByExpr=[c1#13], outputExpr=[c1#13, sum0(count(c1)#19) AS `count(c1)`#15])
+ *  +--LogicalUnion (outputs=[c1#13, count(c1)#19], regularChildrenOutputs=[[c1#9, count(c1)#16],
+ *   [a#4, count(a)#17], [a#7, count(a)#18]])
+ *    |--LogicalAggregate (groupByExpr=[c1#9], outputExpr=[c1#9, count(c1#9) AS `count(c1)`#16])
+ *    |  +--child1
+ *    |--LogicalAggregate (groupByExpr=[a#4], outputExpr=[a#4, count(a#4) AS `count(a)`#17])
+ *    |  +--child2
+ *    +--LogicalAggregate (groupByExpr=[a#7], outputExpr=[a#7, count(a#7) AS `count(a)`#18])
+ *      +--child3
+ *
+ * The rule only fires when:
+ *  - the union is UNION ALL, has at least one child and no constant rows;
+ *  - every aggregate function is a non-distinct COUNT;
+ *  - every output is either a plain slot or an alias directly wrapping COUNT;
+ *  - every group-by key is a slot that is also an output, so the rewritten union can re-expose it.
+ * An aggregate over a project that only produces the literal 1 (and whose slot is never referenced) is
+ * handled by dropping that project first.
+ */
+public class PushCountIntoUnionAll implements RewriteRuleFactory {
+    @Override
+    public List<Rule> buildRules() {
+        return ImmutableList.of(
+                logicalAggregate(logicalUnion().when(this::checkUnion))
+                        .when(this::checkAgg)
+                        .then(this::doPush)
+                        .toRule(RuleType.PUSH_COUNT_INTO_UNION_ALL),
+                logicalAggregate(logicalProject(logicalUnion().when(this::checkUnion)))
+                        .when(this::checkAgg)
+                        .when(this::checkProjectUseless)
+                        .then(this::removeProjectAndPush)
+                        .toRule(RuleType.PUSH_COUNT_INTO_UNION_ALL)
+        );
+    }
+
+    /**
+     * Rewrite agg(union(children)) into agg(sum0)(union(agg(count)(child) for each child)).
+     *
+     * @param agg the matched aggregate whose child is the union
+     * @return the rewritten plan, or {@code agg} unchanged when its outputs cannot be rebuilt
+     */
+    private Plan doPush(LogicalAggregate<LogicalUnion> agg) {
+        LogicalUnion logicalUnion = agg.child();
+        List<NamedExpression> upperOutputExpressions = agg.getOutputExpressions();
+
+        // Outputs of the rewritten union. Group-by slots pass through unchanged; every COUNT alias becomes a
+        // fresh slot carrying the partial count produced by each child aggregate. Validate this first so we
+        // never build child aggregates we would have to throw away.
+        List<NamedExpression> newLogicalUnionOutputs = Lists.newArrayList();
+        for (NamedExpression ce : upperOutputExpressions) {
+            if (ce instanceof Alias) {
+                newLogicalUnionOutputs.add(new SlotReference(ce.getName(), ce.getDataType(), ce.nullable()));
+            } else if (ce instanceof SlotReference) {
+                newLogicalUnionOutputs.add(ce);
+            } else {
+                // Unsupported output shape: keep the original aggregate. Returning the union here would
+                // silently drop the aggregation.
+                return agg;
+            }
+        }
+
+        // position of every union output slot, used to translate it into the matching child slot
+        List<Slot> outputs = logicalUnion.getOutput();
+        Map<Slot, Integer> replaceMap = new HashMap<>();
+        for (int i = 0; i < outputs.size(); i++) {
+            replaceMap.put(outputs.get(i), i);
+        }
+
+        // create the pushed down LogicalAggregate on top of every union child
+        int childSize = logicalUnion.children().size();
+        List<Expression> upperGroupByExpressions = agg.getGroupByExpressions();
+        Builder<Plan> newChildren = ImmutableList.builderWithExpectedSize(childSize);
+        Builder<List<SlotReference>> childrenOutputs = ImmutableList.builderWithExpectedSize(childSize);
+        List<List<SlotReference>> childSlots = logicalUnion.getRegularChildrenOutputs();
+        for (int i = 0; i < childSize; i++) {
+            List<SlotReference> childOutputs = childSlots.get(i);
+            List<Expression> groupByExpressions = replaceExpressionByUnionAll(upperGroupByExpressions, replaceMap,
+                    childOutputs);
+            List<NamedExpression> outputExpressions = replaceExpressionByUnionAll(upperOutputExpressions, replaceMap,
+                    childOutputs);
+            Plan child = logicalUnion.children().get(i);
+            LogicalAggregate<Plan> logicalAggregate = new LogicalAggregate<>(groupByExpressions, outputExpressions,
+                    child);
+            newChildren.add(logicalAggregate);
+            childrenOutputs.add((List<SlotReference>) (List) logicalAggregate.getOutput());
+        }
+
+        // create the new LogicalUnion
+        LogicalSetOperation newLogicalUnion = logicalUnion.withChildrenAndTheirOutputs(newChildren.build(),
+                childrenOutputs.build());
+        newLogicalUnion = newLogicalUnion.withNewOutputs(newLogicalUnionOutputs);
+
+        // The count in the upper agg is converted to sum0 over the partial counts; the alias id and name of the
+        // count remain unchanged so parent operators keep resolving it.
+        Builder<NamedExpression> newUpperOutputExpressions = ImmutableList.builderWithExpectedSize(
+                upperOutputExpressions.size());
+        for (int i = 0; i < upperOutputExpressions.size(); i++) {
+            NamedExpression output = upperOutputExpressions.get(i);
+            if (output instanceof Alias && ((Alias) output).child() instanceof Count) {
+                Alias alias = (Alias) output;
+                Sum0 sumOfCounts = new Sum0(newLogicalUnionOutputs.get(i));
+                newUpperOutputExpressions.add(new Alias(alias.getExprId(), sumOfCounts, alias.getName()));
+            } else {
+                newUpperOutputExpressions.add(output);
+            }
+        }
+        return agg.withAggOutputChild(newUpperOutputExpressions.build(), newLogicalUnion);
+    }
+
+    /**
+     * Translate upper-aggregate expressions so they reference one union child's slots.
+     * Each COUNT alias is recreated with a fresh alias (new expr id) so the child aggregates do not share ids.
+     */
+    private <E extends Expression> List<E> replaceExpressionByUnionAll(List<E> expressions,
+            Map<Slot, Integer> replaceMap, List<? extends Slot> childOutputs) {
+        // Traverse expressions. If a slot in replaceMap appears, replace it with childOutputs[replaceMap[slot]]
+        return ExpressionUtils.rewriteDownShortCircuit(expressions, expr -> {
+            if (expr instanceof Alias && ((Alias) expr).child() instanceof Count) {
+                Count cnt = (Count) ((Alias) expr).child();
+                if (cnt.isCountStar()) {
+                    return new Alias(new Count());
+                } else {
+                    Expression newCntChild = cnt.child(0).rewriteDownShortCircuit(e -> {
+                        if (e instanceof SlotReference && replaceMap.containsKey(e)) {
+                            return childOutputs.get(replaceMap.get(e));
+                        }
+                        return e;
+                    });
+                    return new Alias(new Count(newCntChild));
+                }
+            } else if (expr instanceof SlotReference && replaceMap.containsKey(expr)) {
+                return childOutputs.get(replaceMap.get(expr));
+            }
+            return expr;
+        });
+    }
+
+    /** Check that the aggregate has a shape doPush can rewrite (see the class comment). */
+    private boolean checkAgg(LogicalAggregate<? extends Plan> aggregate) {
+        List<NamedExpression> outputs = aggregate.getOutputExpressions();
+        Set<Count> counts = ExpressionUtils.collect(outputs, expr -> expr instanceof Count);
+        if (counts.isEmpty()) {
+            return false;
+        }
+        if (hasUnsupportedAggFunc(aggregate)) {
+            return false;
+        }
+        // doPush can only rebuild plain slots and aliases that directly wrap COUNT
+        for (NamedExpression output : outputs) {
+            boolean isSlot = output instanceof SlotReference;
+            boolean isAliasedCount = output instanceof Alias && ((Alias) output).child() instanceof Count;
+            if (!isSlot && !isAliasedCount) {
+                return false;
+            }
+        }
+        // the rewritten union only exposes the aggregate's outputs, so every group-by key must be one of them;
+        // otherwise the upper aggregate would group by a slot its child no longer produces
+        for (Expression groupBy : aggregate.getGroupByExpressions()) {
+            if (!(groupBy instanceof SlotReference) || !outputs.contains(groupBy)) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    /** Check that the project only produces the literal 1 and that its slot is unused by the aggregate. */
+    private boolean checkProjectUseless(LogicalAggregate<LogicalProject<LogicalUnion>> agg) {
+        LogicalProject<LogicalUnion> project = agg.child();
+        if (project.getProjects().size() != 1) {
+            return false;
+        }
+        if (!(project.getProjects().get(0) instanceof Alias)) {
+            return false;
+        }
+        Alias alias = (Alias) project.getProjects().get(0);
+        if (!alias.child(0).equals(new TinyIntLiteral((byte) 1))) {
+            return false;
+        }
+        // The project may only be dropped when its slot is referenced nowhere in the aggregate, including
+        // nested inside aggregate functions (e.g. count(x)) or in group-by keys.
+        Slot slot = project.getOutput().get(0);
+        return !ExpressionUtils.deapAnyMatch(agg.getOutputExpressions(), expr -> expr.equals(slot))
+                && !ExpressionUtils.deapAnyMatch(agg.getGroupByExpressions(), expr -> expr.equals(slot));
+    }
+
+    /** Drop the useless project and push the count; keep the original plan if the push does not apply. */
+    private Plan removeProjectAndPush(LogicalAggregate<LogicalProject<LogicalUnion>> agg) {
+        Plan afterRemove = agg.withChildren(agg.child().child());
+        LogicalAggregate<LogicalUnion> aggOnUnion = (LogicalAggregate<LogicalUnion>) afterRemove;
+        Plan pushed = doPush(aggOnUnion);
+        return pushed == aggOnUnion ? agg : pushed;
+    }
+
+    private boolean hasUnsupportedAggFunc(LogicalAggregate<? extends Plan> aggregate) {
+        // only support count; do not support sum, min... and do not support count(distinct)
+        return ExpressionUtils.deapAnyMatch(aggregate.getOutputExpressions(), expr -> {
+            if (expr instanceof AggregateFunction) {
+                if (!(expr instanceof Count)) {
+                    return true;
+                } else {
+                    return ((Count) expr).isDistinct();
+                }
+            } else {
+                return false;
+            }
+        });
+    }
+
+    /** Only UNION ALL with at least one child and no constant rows can be rewritten. */
+    private boolean checkUnion(LogicalUnion union) {
+        return union.getQualifier() == Qualifier.ALL
+                && union.children() != null
+                && !union.children().isEmpty()
+                && union.getConstantExprsList().isEmpty();
+    }
+}
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushCountIntoUnionAllTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushCountIntoUnionAllTest.java
new file mode 100644
index 00000000000..9de03dd6a57
--- /dev/null
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushCountIntoUnionAllTest.java
@@ -0,0 +1,170 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum0;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.utframe.TestWithFeService;
+
+import org.junit.jupiter.api.Test;
+
+public class PushCountIntoUnionAllTest extends TestWithFeService implements MemoPatternMatchSupported {
+    @Override
+    protected void runBeforeAll() throws Exception {
+        createDatabase("test");
+        createTable("create table test.t1 (\n"
+                + "id int not null,\n"
+                + "a varchar(128),\n"
+                + "b int,c int)\n"
+                + "distributed by hash(id) buckets 10\n"
+                + "properties('replication_num' = '1');");
+        connectContext.setDatabase("test");
+        connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION");
+    }
+
+    @Test
+    void testPushCountStar() {
+        String sql = "select id,count(1) from (select id,a from t1 union all select id,a from t1 where id>10) t"
+                + " group by id;";
+        assertPushedWithCountInChildren(sql);
+        String sql2 = "select id,count(*) from (select id,a from t1 union all select id,a from t1 where id>10) t"
+                + " group by id;";
+        assertPushedWithCountInChildren(sql2);
+    }
+
+    @Test
+    void testPushCountStarNoOtherColumn() {
+        String sql = "select count(1) from (select id,a from t1 union all select id,a from t1 where id>10) t;";
+        assertPushed(sql);
+        String sql2 = "select count(*) from (select id,a from t1 union all select id,a from t1 where id>10) t;";
+        assertPushed(sql2);
+    }
+
+    @Test
+    void testPushCountColumn() {
+        String sql = "select count(id) from (select id,a from t1 union all select id,a from t1 where id>10) t;";
+        assertPushedWithCountInChildren(sql);
+    }
+
+    @Test
+    void testPushCountColumnWithGroupBy() {
+        String sql = "select count(id),a from (select id,a from t1 union all select id,a from t1 where id>10) t"
+                + " group by a;";
+        PlanChecker.from(connectContext)
+                .analyze(sql)
+                .rewrite()
+                .matches(
+                        logicalAggregate(
+                                logicalUnion(
+                                        logicalAggregate().when(agg ->
+                                                ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class)
+                                                        && agg.getGroupByExpressions().size() == 1),
+                                        logicalAggregate().when(agg ->
+                                                ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class)
+                                                        && agg.getGroupByExpressions().size() == 1))
+                        ).when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Sum0.class))
+                );
+    }
+
+    @Test
+    void testPush2CountColumn() {
+        String sql = "select count(id), count(b), a from (select id,b,a from t1 union all select id,a,b from t1"
+                + " where id>10) t group by a;";
+        PlanChecker.from(connectContext)
+                .analyze(sql)
+                .rewrite()
+                .matches(
+                        logicalAggregate(
+                                logicalUnion(
+                                        logicalAggregate().when(agg ->
+                                                ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class)
+                                                        && agg.getGroupByExpressions().size() == 1),
+                                        logicalAggregate().when(agg ->
+                                                ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class)
+                                                        && agg.getGroupByExpressions().size() == 1))
+                        ).when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Sum0.class))
+                );
+    }
+
+    @Test
+    void testNotPushCountBecauseOtherAggFunc() {
+        String sql = "select count(1), sum(id) from (select id,a from t1 union all select id,a from t1"
+                + " where id>10) t;";
+        assertNotPushed(sql);
+    }
+
+    @Test
+    void testNotPushCountBecauseUnion() {
+        // Only count is used, so the union qualifier / constant row is the sole reason the rule must not fire.
+        // (A sum() here would block the rule on its own and hide whether the union check works.)
+        String sql = "select count(1) from (select id,a from t1 union select id,a from t1 where id>10) t;";
+        assertNotPushed(sql);
+
+        String sql2 = "select count(1) from (select id,a from t1 union all select id,a from t1 where id>10"
+                + " union all select 1,3) t;";
+        assertNotPushed(sql2);
+    }
+
+    /** Asserts a Sum0 aggregate sits on a two-child union whose children are both aggregates. */
+    private void assertPushed(String sql) {
+        PlanChecker.from(connectContext)
+                .analyze(sql)
+                .rewrite()
+                .matches(
+                        logicalAggregate(
+                                logicalUnion(logicalAggregate(), logicalAggregate())
+                        ).when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Sum0.class))
+                );
+    }
+
+    /** Like {@link #assertPushed(String)}, and additionally requires both child aggregates to output a count. */
+    private void assertPushedWithCountInChildren(String sql) {
+        PlanChecker.from(connectContext)
+                .analyze(sql)
+                .rewrite()
+                .matches(
+                        logicalAggregate(
+                                logicalUnion(
+                                        logicalAggregate().when(agg ->
+                                                ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class)),
+                                        logicalAggregate().when(agg ->
+                                                ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class)))
+                        ).when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Sum0.class))
+                );
+    }
+
+    /** Asserts no Sum0 aggregate over a union of two aggregates appears, i.e. the count was not pushed. */
+    private void assertNotPushed(String sql) {
+        PlanChecker.from(connectContext)
+                .analyze(sql)
+                .rewrite()
+                .nonMatch(
+                        logicalAggregate(
+                                logicalUnion(logicalAggregate(), logicalAggregate())
+                        ).when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Sum0.class))
+                );
+    }
+}
diff --git 
a/regression-test/data/nereids_rules_p0/push_count_into_union_all/push_count_into_union_all.out
 
b/regression-test/data/nereids_rules_p0/push_count_into_union_all/push_count_into_union_all.out
new file mode 100644
index 00000000000..cfdef2ebfec
--- /dev/null
+++ 
b/regression-test/data/nereids_rules_p0/push_count_into_union_all/push_count_into_union_all.out
@@ -0,0 +1,265 @@
+-- This file is automatically generated. You should know what you did if you 
want to edit this
+-- !count_group_by --
+1      10
+2      2
+3      6
+4      2
+5      10
+7      2
+
+-- !count_group_by_shape --
+PhysicalResultSink
+--hashAgg[GLOBAL]
+----hashAgg[LOCAL]
+------PhysicalUnion
+--------hashAgg[GLOBAL]
+----------hashAgg[LOCAL]
+------------filter((mal_test_push_count.a > 1))
+--------------PhysicalOlapScan[mal_test_push_count]
+--------hashAgg[GLOBAL]
+----------hashAgg[LOCAL]
+------------filter((mal_test_push_count.a < 100))
+--------------PhysicalOlapScan[mal_test_push_count]
+--------hashAgg[GLOBAL]
+----------hashAgg[LOCAL]
+------------filter((mal_test_push_count.a = 1))
+--------------PhysicalOlapScan[mal_test_push_count]
+
+-- !count_group_by_none --
+32
+
+-- !count_group_by_none_shape --
+PhysicalResultSink
+--hashAgg[GLOBAL]
+----hashAgg[LOCAL]
+------PhysicalUnion
+--------hashAgg[GLOBAL]
+----------hashAgg[LOCAL]
+------------filter((mal_test_push_count.a > 1))
+--------------PhysicalOlapScan[mal_test_push_count]
+--------hashAgg[GLOBAL]
+----------hashAgg[LOCAL]
+------------filter((mal_test_push_count.a < 100))
+--------------PhysicalOlapScan[mal_test_push_count]
+--------hashAgg[GLOBAL]
+----------hashAgg[LOCAL]
+------------filter((mal_test_push_count.a = 1))
+--------------PhysicalOlapScan[mal_test_push_count]
+
+-- !count_expr_group_by_none --
+32
+
+-- !count_expr_group_by_none_shape --
+PhysicalResultSink
+--hashAgg[GLOBAL]
+----hashAgg[LOCAL]
+------PhysicalUnion
+--------hashAgg[GLOBAL]
+----------hashAgg[LOCAL]
+------------filter((mal_test_push_count.a > 1))
+--------------PhysicalOlapScan[mal_test_push_count]
+--------hashAgg[GLOBAL]
+----------hashAgg[LOCAL]
+------------filter((mal_test_push_count.a < 100))
+--------------PhysicalOlapScan[mal_test_push_count]
+--------hashAgg[GLOBAL]
+----------hashAgg[LOCAL]
+------------filter((mal_test_push_count.a = 1))
+--------------PhysicalOlapScan[mal_test_push_count]
+
+-- !union_const --
+28
+
+-- !union_const_shape --
+PhysicalResultSink
+--hashAgg[GLOBAL]
+----hashAgg[LOCAL]
+------PhysicalUnion
+--------filter((mal_test_push_count.a > 1))
+----------PhysicalOlapScan[mal_test_push_count]
+--------filter((mal_test_push_count.a < 100))
+----------PhysicalOlapScan[mal_test_push_count]
+
+-- !count_expr_group_by_ditinct_none_shape --
+PhysicalResultSink
+--hashAgg[DISTINCT_GLOBAL]
+----hashAgg[DISTINCT_LOCAL]
+------hashAgg[GLOBAL]
+--------hashAgg[LOCAL]
+----------PhysicalUnion
+------------filter((mal_test_push_count.a > 1))
+--------------PhysicalOlapScan[mal_test_push_count]
+------------filter((mal_test_push_count.a < 100))
+--------------PhysicalOlapScan[mal_test_push_count]
+------------filter((mal_test_push_count.a = 1))
+--------------PhysicalOlapScan[mal_test_push_count]
+
+-- !union_all_child_alias --
+1      10
+2      2
+3      6
+4      2
+5      10
+7      2
+
+-- !union_all_child_alias_shape --
+PhysicalResultSink
+--hashAgg[GLOBAL]
+----hashAgg[LOCAL]
+------PhysicalUnion
+--------hashAgg[GLOBAL]
+----------hashAgg[LOCAL]
+------------filter((mal_test_push_count.a > 1))
+--------------PhysicalOlapScan[mal_test_push_count]
+--------hashAgg[GLOBAL]
+----------hashAgg[LOCAL]
+------------filter((mal_test_push_count.a < 100))
+--------------PhysicalOlapScan[mal_test_push_count]
+--------hashAgg[GLOBAL]
+----------hashAgg[LOCAL]
+------------filter((mal_test_push_count.a = 1))
+--------------PhysicalOlapScan[mal_test_push_count]
+
+-- !union_all_child_expr --
+1      5
+3      1
+4      3
+5      1
+6      5
+8      1
+101    5
+102    1
+103    3
+104    1
+105    5
+107    1
+
+-- !count_group_by_count_other --
+1      10
+2      2
+3      6
+4      2
+5      8
+7      2
+
+-- !count_group_by_count_other_shape --
+PhysicalResultSink
+--PhysicalQuickSort[MERGE_SORT]
+----PhysicalQuickSort[LOCAL_SORT]
+------hashAgg[GLOBAL]
+--------hashAgg[LOCAL]
+----------PhysicalUnion
+------------hashAgg[GLOBAL]
+--------------hashAgg[LOCAL]
+----------------filter((mal_test_push_count.a > 1))
+------------------PhysicalOlapScan[mal_test_push_count]
+------------hashAgg[GLOBAL]
+--------------hashAgg[LOCAL]
+----------------filter((mal_test_push_count.a < 100))
+------------------PhysicalOlapScan[mal_test_push_count]
+------------hashAgg[GLOBAL]
+--------------hashAgg[LOCAL]
+----------------filter((mal_test_push_count.a = 1))
+------------------PhysicalOlapScan[mal_test_push_count]
+
+-- !count_group_by_multi_col --
+1      4
+1      6
+2      2
+3      2
+3      4
+4      2
+5      2
+5      6
+7      2
+
+-- !count_group_by_multi_col_shape --
+PhysicalResultSink
+--PhysicalQuickSort[MERGE_SORT]
+----PhysicalQuickSort[LOCAL_SORT]
+------hashAgg[GLOBAL]
+--------hashAgg[LOCAL]
+----------PhysicalUnion
+------------hashAgg[LOCAL]
+--------------filter((mal_test_push_count.a > 1))
+----------------PhysicalOlapScan[mal_test_push_count]
+------------hashAgg[LOCAL]
+--------------filter((mal_test_push_count.a < 100))
+----------------PhysicalOlapScan[mal_test_push_count]
+------------hashAgg[LOCAL]
+--------------filter((mal_test_push_count.a = 1))
+----------------PhysicalOlapScan[mal_test_push_count]
+
+-- !test_upper_refer --
+1      10
+2      2
+3      6
+4      2
+5      8
+7      2
+
+-- !test_upper_refer_shape --
+PhysicalResultSink
+--PhysicalQuickSort[MERGE_SORT]
+----PhysicalQuickSort[LOCAL_SORT]
+------hashAgg[GLOBAL]
+--------hashAgg[LOCAL]
+----------PhysicalUnion
+------------hashAgg[GLOBAL]
+--------------hashAgg[LOCAL]
+----------------filter((mal_test_push_count.a > 1))
+------------------PhysicalOlapScan[mal_test_push_count]
+------------hashAgg[GLOBAL]
+--------------hashAgg[LOCAL]
+----------------filter((mal_test_push_count.a < 100))
+------------------PhysicalOlapScan[mal_test_push_count]
+------------hashAgg[GLOBAL]
+--------------hashAgg[LOCAL]
+----------------filter((mal_test_push_count.a = 1))
+------------------PhysicalOlapScan[mal_test_push_count]
+
+-- !unsupport_agg_func --
+PhysicalResultSink
+--PhysicalQuickSort[MERGE_SORT]
+----PhysicalQuickSort[LOCAL_SORT]
+------hashAgg[GLOBAL]
+--------hashAgg[LOCAL]
+----------PhysicalUnion
+------------filter((mal_test_push_count.a > 1))
+--------------PhysicalOlapScan[mal_test_push_count]
+------------filter((mal_test_push_count.a < 100))
+--------------PhysicalOlapScan[mal_test_push_count]
+------------filter((mal_test_push_count.a = 1))
+--------------PhysicalOlapScan[mal_test_push_count]
+
+-- !test_upper_refer_count_star --
+1      10
+2      2
+3      6
+4      2
+5      10
+7      2
+
+-- !test_upper_refer_count_star_shape --
+PhysicalResultSink
+--PhysicalQuickSort[MERGE_SORT]
+----PhysicalQuickSort[LOCAL_SORT]
+------hashAgg[GLOBAL]
+--------hashAgg[LOCAL]
+----------PhysicalUnion
+------------hashAgg[GLOBAL]
+--------------hashAgg[LOCAL]
+----------------filter((mal_test_push_count.a > 1))
+------------------PhysicalOlapScan[mal_test_push_count]
+------------hashAgg[GLOBAL]
+--------------hashAgg[LOCAL]
+----------------filter((mal_test_push_count.a < 100))
+------------------PhysicalOlapScan[mal_test_push_count]
+------------hashAgg[GLOBAL]
+--------------hashAgg[LOCAL]
+----------------filter((mal_test_push_count.a = 1))
+------------------PhysicalOlapScan[mal_test_push_count]
+
+-- !test_count_star --
+32
+
diff --git 
a/regression-test/suites/nereids_rules_p0/push_count_into_union_all/push_count_into_union_all.groovy
 
b/regression-test/suites/nereids_rules_p0/push_count_into_union_all/push_count_into_union_all.groovy
new file mode 100644
index 00000000000..71e07318820
--- /dev/null
+++ 
b/regression-test/suites/nereids_rules_p0/push_count_into_union_all/push_count_into_union_all.groovy
@@ -0,0 +1,156 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Regression suite for the Nereids rewrite rule PushCountIntoUnionAll.
+//
+// Most cases come in pairs: qt_<name> checks the result rows and
+// qt_<name>_shape checks the "explain shape plan" output, i.e. whether the
+// aggregation sits above the PhysicalUnion or inside each of its branches.
+// Expected output is recorded under the same tags in push_count_into_union_all.out.
+suite("push_count_into_union_all") {
+    sql "SET enable_nereids_planner=true"
+    sql "SET enable_fallback_to_original_planner=false"
+    // Omit these nodes from shape output; the recorded shapes do not contain them.
+    sql """SET ignore_shape_nodes='PhysicalDistribute,PhysicalProject'"""
+    sql """
+          DROP TABLE IF EXISTS mal_test_push_count
+         """
+
+    sql """
+         create table mal_test_push_count(pk int, a int, b int) distributed by hash(pk) buckets 10
+         properties('replication_num' = '1');
+         """
+
+    // 17 rows; one row has a NULL in column a and one has a NULL in column b.
+    sql """
+         insert into mal_test_push_count values(2,1,3),(1,1,2),(3,5,6),(6,null,6),(4,5,6),(2,1,4),(2,3,5),(1,1,4)
+        ,(3,5,6),(3,5,null),(6,7,1),(2,1,7),(2,4,2),(2,3,9),(1,3,6),(3,5,8),(3,2,8);
+      """
+    sql "sync"
+
+    // count(a) grouped by a.
+    qt_count_group_by """
+        select a,count(a) c1 from (select a,b from mal_test_push_count where a>1 union all select a,b from mal_test_push_count where a<100
+        union all select a,b from mal_test_push_count where a=1 ) t group by a order by 1,2;"""
+
+    qt_count_group_by_shape """explain shape plan
+        select a,count(a) c1 from (select a,b from mal_test_push_count where a>1 union all select a,b from mal_test_push_count where a<100
+        union all select a,b from mal_test_push_count where a=1 ) t group by a;"""
+
+    // Scalar count (no group by), on a column and on an expression.
+    qt_count_group_by_none """
+        select count(a) c1 from (select a,b from mal_test_push_count where a>1 union all select a,b from mal_test_push_count where a<100
+        union all select a,b from mal_test_push_count where a=1 ) t order by 1"""
+
+    qt_count_group_by_none_shape """explain shape plan
+        select count(a) c1 from (select a,b from mal_test_push_count where a>1 union all select a,b from mal_test_push_count where a<100
+        union all select a,b from mal_test_push_count where a=1 ) t"""
+
+    qt_count_expr_group_by_none """
+        select count(a+1) c1 from (select a,b from mal_test_push_count where a>1 union all select a,b from mal_test_push_count where a<100
+        union all select a,b from mal_test_push_count where a=1 ) t order by 1"""
+
+    qt_count_expr_group_by_none_shape """explain shape plan
+        select count(a+1) c1 from (select a,b from mal_test_push_count where a>1 union all select a,b from mal_test_push_count where a<100
+        union all select a,b from mal_test_push_count where a=1 ) t"""
+
+    // One UNION ALL branch is a constant row (select 1,2) instead of a table scan.
+    qt_union_const """
+        select count(a+1) c1 from (select a,b from mal_test_push_count where a>1 union all select a,b from mal_test_push_count where a<100
+        union all select 1,2 ) t order by 1"""
+
+    qt_union_const_shape """explain shape plan
+        select count(a+1) c1 from (select a,b from mal_test_push_count where a>1 union all select a,b from mal_test_push_count where a<100
+        union all select 1,2 ) t"""
+
+    // count(distinct ...) over the union.
+    // NOTE(review): the tag misspells "distinct"; renaming it also requires
+    // renaming the matching tag in push_count_into_union_all.out.
+    qt_count_expr_group_by_ditinct_none_shape """explain shape plan
+        select count(distinct a+1) c1 from (select a,b from mal_test_push_count where a>1 union all select a,b from mal_test_push_count where a<100
+        union all select a,b from mal_test_push_count where a=1 ) t"""
+
+    // The first branch aliases its columns, so the union output is named c1/c2.
+    qt_union_all_child_alias """
+        select c1,count(c1)  from (select a as c1,b as c2 from mal_test_push_count where a>1 union all select a,b from mal_test_push_count where a<100
+        union all select a,b from mal_test_push_count where a=1 ) t group by c1 order by 1,2;"""
+
+    qt_union_all_child_alias_shape """
+        explain shape plan
+        select c1,count(c1)  from (select a as c1,b as c2 from mal_test_push_count where a>1 union all select a,b from mal_test_push_count where a<100
+        union all select a,b from mal_test_push_count where a=1 ) t group by c1;"""
+
+    // Branch outputs are expressions (a+1, a+100, abs(a)) rather than bare columns.
+    qt_union_all_child_expr """
+        select c1,count(c1)  from (select a+1 as c1,b as c2 from mal_test_push_count where a>1 union all select a+100,b from mal_test_push_count where a<100
+        union all select abs(a),b from mal_test_push_count where a=1 ) t group by c1 order by 1,2;"""
+
+    // count on a column other than the group-by key.
+    qt_count_group_by_count_other """
+        select a,count(b) c1 from (select a,b from mal_test_push_count where a>1 union all select a,b from mal_test_push_count where a<100
+        union all select a,b from mal_test_push_count where a=1 ) t group by a order by 1,2;"""
+
+    qt_count_group_by_count_other_shape """
+        explain shape plan
+        select a,count(b) c1 from (select a,b from mal_test_push_count where a>1 union all select a,b from mal_test_push_count where a<100
+        union all select a,b from mal_test_push_count where a=1 ) t group by a order by 1,2;"""
+
+    // Group by two columns.
+    qt_count_group_by_multi_col """
+        select a,count(b) c1 from (select a,b,pk from mal_test_push_count where a>1 union all select a,b,pk from mal_test_push_count where a<100
+        union all select a,b,pk from mal_test_push_count where a=1 ) t group by a,pk order by 1,2;"""
+
+    qt_count_group_by_multi_col_shape """
+        explain shape plan
+        select a,count(b) c1 from (select a,b,pk from mal_test_push_count where a>1 union all select a,b,pk from mal_test_push_count where a<100
+        union all select a,b,pk from mal_test_push_count where a=1 ) t group by a,pk order by 1,2;"""
+
+    // The aggregate is nested in an outer query that reads its count output.
+    qt_test_upper_refer """
+        select a,c1 from (
+        select a,count(b) c1 from (select a,b from mal_test_push_count where a>1 union all select a,b from mal_test_push_count where a<100
+        union all select a,b from mal_test_push_count where a=1 ) t group by a) outer_table order by 1,2;"""
+
+    qt_test_upper_refer_shape """
+        explain shape plan
+        select a,c1 from (
+        select a,count(b) c1 from (select a,b from mal_test_push_count where a>1 union all select a,b from mal_test_push_count where a<100
+        union all select a,b from mal_test_push_count where a=1 ) t group by a) outer_table order by 1,2;"""
+
+    // sum(b) alongside count(b): the recorded shape keeps a single aggregate
+    // above PhysicalUnion, i.e. nothing is pushed into the branches.
+    qt_unsupport_agg_func """
+        explain shape plan
+        select a,count(b) c1,sum(b) from (select a,b from mal_test_push_count where a>1 union all select a,b from mal_test_push_count where a<100
+        union all select a,b from mal_test_push_count where a=1 ) t group by a order by 1,2,3;"""
+
+    // count(*) variants; the recorded shape adds a hashAgg inside every union branch.
+    qt_test_upper_refer_count_star """
+        select a,c1 from (
+        select a,count(*) c1 from (select a,b from mal_test_push_count where a>1 union all select a,b from mal_test_push_count where a<100
+        union all select a,b from mal_test_push_count where a=1 ) t group by a) outer_table order by 1,2;"""
+
+    qt_test_upper_refer_count_star_shape """
+        explain shape plan
+        select a,c1 from (
+        select a,count(*) c1 from (select a,b from mal_test_push_count where a>1 union all select a,b from mal_test_push_count where a<100
+        union all select a,b from mal_test_push_count where a=1 ) t group by a) outer_table order by 1,2;"""
+
+    qt_test_count_star """
+        select count(*) from (select a,b from mal_test_push_count where a>1 union all select a,b from mal_test_push_count where a<100
+        union all select a,b from mal_test_push_count where a=1) t order by 1;"""
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to