This is an automated email from the ASF dual-hosted git repository.

starocean999 pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-2.1 by this push:
     new bcde9c65249 [enhancement](nereids)eliminate repeat node if there is only 1 grouping set and no grouping scalar function (#35872)
bcde9c65249 is described below

commit bcde9c65249ba1e5469a29bd6453acbbf003f0ab
Author: starocean999 <[email protected]>
AuthorDate: Wed Jun 5 18:03:20 2024 +0800

    [enhancement](nereids)eliminate repeat node if there is only 1 grouping set and no grouping scalar function (#35872)
---
 .../nereids/rules/analysis/NormalizeRepeat.java    |  6 ++++
 .../rules/analysis/NormalizeRepeatTest.java        | 39 +++++++++++++++++++++-
 .../grouping_sets/grouping_normalize_test.groovy   | 15 +++++++++
 3 files changed, 59 insertions(+), 1 deletion(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java
index 169d5a901a7..2d39852dd18 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java
@@ -82,6 +82,12 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
     public Rule build() {
         return RuleType.NORMALIZE_REPEAT.build(
             
logicalRepeat(any()).when(LogicalRepeat::canBindVirtualSlot).then(repeat -> {
+                if (repeat.getGroupingSets().size() == 1
+                        && 
ExpressionUtils.collect(repeat.getOutputExpressions(),
+                        GroupingScalarFunction.class::isInstance).isEmpty()) {
+                    return new 
LogicalAggregate<>(repeat.getGroupByExpressions(),
+                            repeat.getOutputExpressions(), repeat.child());
+                }
                 checkRepeatLegality(repeat);
                 repeat = removeDuplicateColumns(repeat);
                 // add virtual slot, LogicalAggregate and LogicalProject for 
normalize
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeatTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeatTest.java
index 3fc2fec9a65..556f5279412 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeatTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeatTest.java
@@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.analysis;
 import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingId;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
@@ -41,7 +42,7 @@ public class NormalizeRepeatTest implements 
MemoPatternMatchSupported {
         Slot name = scan1.getOutput().get(1);
         Alias alias = new Alias(new Sum(name), "sum(name)");
         Plan plan = new LogicalRepeat<>(
-                ImmutableList.of(ImmutableList.of(id)),
+                ImmutableList.of(ImmutableList.of(id), ImmutableList.of(name)),
                 ImmutableList.of(idNotNull, alias),
                 scan1
         );
@@ -51,4 +52,40 @@ public class NormalizeRepeatTest implements 
MemoPatternMatchSupported {
                         logicalRepeat().when(repeat -> 
repeat.getOutputExpressions().get(0).nullable())
                 );
     }
+
+    @Test
+    public void testEliminateRepeat() {
+        Slot id = scan1.getOutput().get(0);
+        Slot idNotNull = id.withNullable(true);
+        Slot name = scan1.getOutput().get(1);
+        Alias alias = new Alias(new Sum(name), "sum(name)");
+        Plan plan = new LogicalRepeat<>(
+                ImmutableList.of(ImmutableList.of(id)),
+                ImmutableList.of(idNotNull, alias),
+                scan1
+        );
+        PlanChecker.from(MemoTestUtils.createCascadesContext(plan))
+                .applyTopDown(new NormalizeRepeat())
+                .matchesFromRoot(
+                        logicalAggregate(logicalOlapScan())
+                );
+    }
+
+    @Test
+    public void testNoEliminateRepeat() {
+        Slot id = scan1.getOutput().get(0);
+        Slot idNotNull = id.withNullable(true);
+        Slot name = scan1.getOutput().get(1);
+        Alias alias = new Alias(new GroupingId(name), "grouping_id(name)");
+        Plan plan = new LogicalRepeat<>(
+                ImmutableList.of(ImmutableList.of(id)),
+                ImmutableList.of(idNotNull, alias),
+                scan1
+        );
+        PlanChecker.from(MemoTestUtils.createCascadesContext(plan))
+                .applyTopDown(new NormalizeRepeat())
+                .matchesFromRoot(
+                        logicalAggregate(logicalRepeat(logicalOlapScan()))
+                );
+    }
 }
diff --git 
a/regression-test/suites/nereids_rules_p0/grouping_sets/grouping_normalize_test.groovy
 
b/regression-test/suites/nereids_rules_p0/grouping_sets/grouping_normalize_test.groovy
index 8310685c9c6..93821452f2f 100644
--- 
a/regression-test/suites/nereids_rules_p0/grouping_sets/grouping_normalize_test.groovy
+++ 
b/regression-test/suites/nereids_rules_p0/grouping_sets/grouping_normalize_test.groovy
@@ -39,4 +39,19 @@ suite("grouping_normalize_test"){
     SELECT  ROUND( SUM(pk  +  1)  -  3)  col_alias1,  MAX( DISTINCT  
col_int_undef_signed  -  5)   AS col_alias2, pk  +  1  AS col_alias3
     FROM grouping_normalize_test  GROUP BY  GROUPING SETS 
((col_int_undef_signed,col_int_undef_signed2,pk),()) order by 1,2,3;
     """
+
+    explain {
+            sql("SELECT col_int_undef_signed, col_int_undef_signed2, SUM(pk) 
FROM grouping_normalize_test GROUP BY GROUPING SETS ((col_int_undef_signed, 
col_int_undef_signed2));")
+            notContains("VREPEAT_NODE")
+    }
+
+    explain {
+            sql("SELECT col_int_undef_signed, col_int_undef_signed2, SUM(pk), 
grouping_id(col_int_undef_signed2) FROM grouping_normalize_test GROUP BY 
GROUPING SETS ((col_int_undef_signed, col_int_undef_signed2),());")
+            contains("VREPEAT_NODE")
+    }
+
+    explain {
+            sql("SELECT col_int_undef_signed, col_int_undef_signed2, SUM(pk) 
FROM grouping_normalize_test GROUP BY GROUPING SETS ((col_int_undef_signed, 
col_int_undef_signed2));")
+            notContains("VREPEAT_NODE")
+    }
 }
\ No newline at end of file


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to