This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-2.1 by this push:
     new ebe302cb7e6 branch-2.1: [fix](nereids) do eliminate constant group by 
key in normalizeagg #49589 (#50212)
ebe302cb7e6 is described below

commit ebe302cb7e690199de814fb2634aacde6d70a65e
Author: feiniaofeiafei <[email protected]>
AuthorDate: Thu May 8 18:52:40 2025 +0800

    branch-2.1: [fix](nereids) do eliminate constant group by key in 
normalizeagg #49589 (#50212)
    
    Cherry-picked from https://github.com/apache/doris/pull/49589
---
 .../doris/nereids/jobs/executor/Analyzer.java      |   3 -
 .../doris/nereids/jobs/executor/Rewriter.java      |   2 -
 .../nereids/rules/analysis/NormalizeAggregate.java | 108 ++++++++++++-
 .../analysis/EliminateGroupByConstantTest.java     | 165 --------------------
 .../rules/analysis/NormalizeAggregateTest.java     | 121 ++++++++++++++-
 .../eliminate_constant_gby_key.out                 | Bin 0 -> 3009 bytes
 .../eliminate_constant_gby_key.groovy              | 172 +++++++++++++++++++++
 7 files changed, 392 insertions(+), 179 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java
index 5b8e52c0289..24129a10834 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java
@@ -30,7 +30,6 @@ import org.apache.doris.nereids.rules.analysis.CheckPolicy;
 import org.apache.doris.nereids.rules.analysis.CollectJoinConstraint;
 import org.apache.doris.nereids.rules.analysis.CollectSubQueryAlias;
 import org.apache.doris.nereids.rules.analysis.EliminateDistinctConstant;
-import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant;
 import org.apache.doris.nereids.rules.analysis.EliminateLogicalSelectHint;
 import org.apache.doris.nereids.rules.analysis.FillUpMissingSlots;
 import org.apache.doris.nereids.rules.analysis.HavingToFilter;
@@ -136,8 +135,6 @@ public class Analyzer extends AbstractBatchJobExecutor {
             // select SUM(lo_tax) FROM lineorder group by 1;
             // errCode = 2, detailMessage = GROUP BY expression must not 
contain aggregate functions: sum(lo_tax)
             bottomUp(new CheckAnalysis()),
-            topDown(new EliminateGroupByConstant()),
-
             topDown(new SimplifyAggGroupBy()),
             topDown(new NormalizeAggregate()),
             topDown(new HavingToFilter()),
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
index 122d6a6cf63..4c712e919ba 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
@@ -25,7 +25,6 @@ import org.apache.doris.nereids.rules.RuleType;
 import 
org.apache.doris.nereids.rules.analysis.AdjustAggregateNullableForEmptySet;
 import org.apache.doris.nereids.rules.analysis.AvgDistinctToSumDivCount;
 import org.apache.doris.nereids.rules.analysis.CheckAfterRewrite;
-import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant;
 import 
org.apache.doris.nereids.rules.analysis.LogicalSubQueryAliasToLogicalProject;
 import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
 import org.apache.doris.nereids.rules.expression.CheckLegalityAfterRewrite;
@@ -158,7 +157,6 @@ public class Rewriter extends AbstractBatchJobExecutor {
                         topDown(
                                 new EliminateOrderByConstant(),
                                 new EliminateSortUnderSubqueryOrView(),
-                                new EliminateGroupByConstant(),
                                 // MergeProjects depends on this rule
                                 new LogicalSubQueryAliasToLogicalProject(),
                                 // TODO: we should do expression normalization 
after plan normalization
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
index e5ebee120a3..4a2e226caae 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
@@ -17,9 +17,12 @@
 
 package org.apache.doris.nereids.rules.analysis;
 
+import org.apache.doris.nereids.CascadesContext;
 import org.apache.doris.nereids.exceptions.AnalysisException;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
+import org.apache.doris.nereids.rules.expression.rules.FoldConstantRuleOnFE;
 import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot;
 import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
 import org.apache.doris.nereids.trees.expressions.Alias;
@@ -35,6 +38,7 @@ import 
org.apache.doris.nereids.trees.expressions.WindowExpression;
 import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
 import 
org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinction;
 import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
 import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
@@ -50,6 +54,7 @@ import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Sets;
 
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
@@ -111,14 +116,16 @@ public class NormalizeAggregate implements 
RewriteRuleFactory, NormalizeToSlot {
     public List<Rule> buildRules() {
         return ImmutableList.of(
+                // thenApply (rather than then) exposes ctx.cascadesContext, which normalizeAgg uses to build
+                // the ExpressionRewriteContext needed for FE constant folding of group-by keys.
                 
logicalHaving(logicalAggregate().whenNot(LogicalAggregate::isNormalized))
-                        .then(having -> normalizeAgg(having.child(), 
Optional.of(having)))
+                        .thenApply(ctx -> normalizeAgg(ctx.root.child(), 
Optional.of(ctx.root), ctx.cascadesContext))
                         .toRule(RuleType.NORMALIZE_AGGREGATE),
                 logicalAggregate().whenNot(LogicalAggregate::isNormalized)
-                        .then(aggregate -> normalizeAgg(aggregate, 
Optional.empty()))
+                        .thenApply(ctx -> normalizeAgg(ctx.root, 
Optional.empty(), ctx.cascadesContext))
                         .toRule(RuleType.NORMALIZE_AGGREGATE));
     }
 
-    private LogicalPlan normalizeAgg(LogicalAggregate<Plan> aggregate, 
Optional<LogicalHaving<?>> having) {
+    @SuppressWarnings("checkstyle:UnusedLocalVariable")
+    private LogicalPlan normalizeAgg(LogicalAggregate<Plan> aggregate, 
Optional<LogicalHaving<?>> having,
+            CascadesContext ctx) {
         // The LogicalAggregate node may contain window agg functions and 
usual agg functions
         // we call window agg functions as window-agg and usual agg functions 
as trivial-agg for short
         // This rule simplify LogicalAggregate node by:
@@ -279,8 +286,10 @@ public class NormalizeAggregate implements 
RewriteRuleFactory, NormalizeToSlot {
         List<NamedExpression> upperProjects = normalizeOutput(aggregateOutput,
                 groupByExprContext, argsOfAggFuncNeedPushDownContext, 
normalizedAggFuncsToSlotContext);
 
-        // create a parent project node
-        LogicalProject<Plan> project = new LogicalProject<>(upperProjects, 
newAggregate);
+        ExpressionRewriteContext rewriteContext = new 
ExpressionRewriteContext(ctx);
+        LogicalProject<Plan> project = 
eliminateGroupByConstant(groupByExprContext, rewriteContext,
+                normalizedGroupExprs, normalizedAggOutput, bottomProjects, 
aggregate, upperProjects, newAggregate);
+
         // verify project used slots are all coming from agg's output
         List<Slot> slots = collectAllUsedSlots(upperProjects);
         if (!slots.isEmpty()) {
@@ -389,4 +398,93 @@ public class NormalizeAggregate implements 
RewriteRuleFactory, NormalizeToSlot {
             return expr;
         }
     }
+
+    /**
+     * Removes GROUP BY keys that fold to constants on FE and rebuilds the aggregate around the remaining keys.
+     *
+     * <p>Steps: (1) fold every group-by expression with {@link FoldConstantRuleOnFE} and remember the slots
+     * whose expression folds to a constant; (2) drop those slots from the group-by list, using a single
+     * TINYINT 1 key when no key would remain; (3) substitute the constants into aliased aggregate outputs and
+     * drop the bare constant-key slots; (4) stop computing the constants in the bottom project; (5) re-create
+     * them in the upper project under the original slot's expr id and name.
+     *
+     * @param groupByExprContext maps each original group-by expression to its normalized slot
+     * @param rewriteContext context for FE constant folding
+     * @param normalizedGroupExprs group-by list of the normalized aggregate
+     * @param normalizedAggOutput output list of the normalized aggregate; not modified
+     * @param bottomProjects projections placed below the aggregate; not modified
+     * @param aggregate the original aggregate; its child and {@code withNormalized} build the result
+     * @param upperProjects projections placed above the aggregate; not modified
+     * @param newAggregate the normalized aggregate, used as-is when no key is eliminated
+     * @return the project that replaces the original aggregate
+     */
+    private LogicalProject<Plan> eliminateGroupByConstant(NormalizeToSlotContext groupByExprContext,
+            ExpressionRewriteContext rewriteContext, List<Expression> normalizedGroupExprs,
+            List<NamedExpression> normalizedAggOutput, Set<NamedExpression> bottomProjects,
+            LogicalAggregate<Plan> aggregate, List<NamedExpression> upperProjects, LogicalAggregate<?> newAggregate) {
+        // 1. Find the group-by expressions that fold to constants and map their slots to the folded value
+        Map<Expression, NormalizeToSlotTriplet> replaceMap = groupByExprContext.getNormalizeToSlotMap();
+        if (replaceMap.isEmpty()) {
+            return new LogicalProject<>(upperProjects, newAggregate);
+        }
+        Map<Slot, Expression> slotToLiteral = new HashMap<>();
+        for (Map.Entry<Expression, NormalizeToSlotTriplet> entry : replaceMap.entrySet()) {
+            Expression foldExpression = FoldConstantRuleOnFE.evaluate(entry.getKey(), rewriteContext);
+            if (foldExpression.isConstant()) {
+                slotToLiteral.put(entry.getValue().remainExpr, foldExpression);
+            }
+        }
+        if (slotToLiteral.isEmpty()) {
+            return new LogicalProject<>(upperProjects, newAggregate);
+        }
+        // 2. Regenerate the group-by list without the constant keys.
+        //    Map.containsKey takes any Object, so no (Slot) cast is needed (and none can fail).
+        List<Expression> newNormalizedGroupExprs = new ArrayList<>();
+        for (Expression normalizedGroupExpr : normalizedGroupExprs) {
+            if (!slotToLiteral.containsKey(normalizedGroupExpr)) {
+                newNormalizedGroupExprs.add(normalizedGroupExpr);
+            }
+        }
+        if (newNormalizedGroupExprs.size() == normalizedGroupExprs.size()) {
+            return new LogicalProject<>(upperProjects, newAggregate);
+        }
+        // Work on local, order-preserving copies: a HashSet copy of bottomProjects would not keep the
+        // incoming iteration order, making the rebuilt bottom project's output order run-dependent.
+        List<NamedExpression> newBottomProjects = new ArrayList<>(bottomProjects);
+        List<NamedExpression> aggOutput = normalizedAggOutput;
+        if (newNormalizedGroupExprs.isEmpty()) {
+            // Every key folded to a constant. Keep one synthetic constant key rather than an empty
+            // group-by list -- presumably so the aggregate keeps GROUP BY semantics (no output row for
+            // empty input) instead of turning into a scalar aggregate.
+            Alias tinyInt = new Alias(new TinyIntLiteral((byte) 1));
+            newBottomProjects.add(tinyInt);
+            Slot tinyIntSlot = tinyInt.toSlot();
+            aggOutput = new ArrayList<>(normalizedAggOutput);
+            aggOutput.add(tinyIntSlot);
+            newNormalizedGroupExprs.add(tinyIntSlot);
+        }
+        // 3. Substitute constants inside aliased aggregate outputs and drop the bare constant-key slots
+        ImmutableList.Builder<NamedExpression> nonConstAggOutput = ImmutableList.builder();
+        for (NamedExpression ne : aggOutput) {
+            if (ne instanceof Alias) {
+                nonConstAggOutput.add(ExpressionUtils.replaceNameExpression(ne, slotToLiteral));
+            } else if (!(ne instanceof Slot) || !slotToLiteral.containsKey(ne)) {
+                nonConstAggOutput.add(ne);
+            }
+        }
+
+        // 4. Constant keys are no longer computed below the aggregate; step 5 re-creates them above it
+        Plan bottomPlan;
+        if (newBottomProjects.isEmpty()) {
+            bottomPlan = aggregate.child();
+        } else {
+            ImmutableList.Builder<NamedExpression> builder = ImmutableList.builder();
+            for (NamedExpression bottomProject : newBottomProjects) {
+                if (!slotToLiteral.containsKey(bottomProject.toSlot())) {
+                    builder.add(bottomProject);
+                }
+            }
+            bottomPlan = new LogicalProject<>(builder.build(), aggregate.child());
+        }
+        LogicalAggregate<Plan> newAggAfterEliminate = aggregate.withNormalized(newNormalizedGroupExprs,
+                nonConstAggOutput.build(), bottomPlan);
+        // 5. Re-add the eliminated constant keys to the upper project and replace references to them
+        ImmutableList.Builder<NamedExpression> newUpperProjects = ImmutableList.builder();
+        for (NamedExpression upperProject : upperProjects) {
+            if (upperProject instanceof Alias) {
+                newUpperProjects.add(ExpressionUtils.replaceNameExpression(upperProject, slotToLiteral));
+            } else if (upperProject instanceof Slot && slotToLiteral.containsKey(upperProject)) {
+                // Reuse the slot's ExprId and name so this project's output slot stays the same
+                newUpperProjects.add(new Alias(upperProject.getExprId(), slotToLiteral.get(upperProject),
+                        upperProject.getName()));
+            } else {
+                newUpperProjects.add(upperProject);
+            }
+        }
+        return new LogicalProject<>(newUpperProjects.build(), newAggAfterEliminate);
+    }
 }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstantTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstantTest.java
deleted file mode 100644
index c35b983911c..00000000000
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstantTest.java
+++ /dev/null
@@ -1,165 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.doris.nereids.rules.analysis;
-
-import org.apache.doris.catalog.AggregateType;
-import org.apache.doris.catalog.Column;
-import org.apache.doris.catalog.KeysType;
-import org.apache.doris.catalog.OlapTable;
-import org.apache.doris.catalog.PartitionInfo;
-import org.apache.doris.catalog.Type;
-import org.apache.doris.nereids.trees.expressions.Add;
-import org.apache.doris.nereids.trees.expressions.Alias;
-import org.apache.doris.nereids.trees.expressions.Slot;
-import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
-import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
-import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
-import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
-import org.apache.doris.nereids.trees.plans.RelationId;
-import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
-import org.apache.doris.nereids.util.LogicalPlanBuilder;
-import org.apache.doris.nereids.util.MemoPatternMatchSupported;
-import org.apache.doris.nereids.util.MemoTestUtils;
-import org.apache.doris.nereids.util.PlanChecker;
-import org.apache.doris.thrift.TStorageType;
-
-import com.google.common.collect.ImmutableList;
-import org.junit.jupiter.api.Test;
-
-/** Tests for {@link EliminateGroupByConstant}. */
-class EliminateGroupByConstantTest implements MemoPatternMatchSupported {
-    private static final OlapTable table = new OlapTable(0L, "student",
-            ImmutableList.of(new Column("k1", Type.INT, true, 
AggregateType.NONE, "0", ""),
-                    new Column("k2", Type.INT, false, AggregateType.NONE, "0", 
""),
-                    new Column("k3", Type.INT, true, AggregateType.NONE, "", 
"")),
-            KeysType.PRIMARY_KEYS, new PartitionInfo(), null);
-
-    static {
-        table.setIndexMeta(-1,
-                "t1",
-                table.getFullSchema(),
-                0, 0, (short) 0,
-                TStorageType.COLUMN,
-                KeysType.PRIMARY_KEYS);
-    }
-
-    private static final LogicalOlapScan scan = new 
LogicalOlapScan(RelationId.createGenerator().getNextId(), table);
-    private static final Slot k1 = scan.getOutput().get(0);
-    private static final Slot k2 = scan.getOutput().get(1);
-
-    @Test
-    void testIntegerLiteral() {
-        LogicalPlan aggregate = new LogicalPlanBuilder(scan)
-                .agg(ImmutableList.of(new IntegerLiteral(1), k2),
-                     ImmutableList.of(k1, k2))
-                .build();
-
-        PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
-                .applyTopDown(new EliminateGroupByConstant())
-                .applyBottomUp(new CheckAfterRewrite())
-                .matches(
-                        aggregate().when(agg -> 
agg.getGroupByExpressions().equals(ImmutableList.of(k2)))
-                );
-    }
-
-    @Test
-    void testOtherLiteral() {
-        LogicalPlan aggregate = new LogicalPlanBuilder(scan)
-                .agg(ImmutableList.of(
-                             new StringLiteral("str"), k2),
-                     ImmutableList.of(
-                             new Alias(new StringLiteral("str"), "str"), k1, 
k2))
-                .build();
-
-        PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
-                .applyTopDown(new EliminateGroupByConstant())
-                .applyBottomUp(new CheckAfterRewrite())
-                .matches(
-                        aggregate().when(agg -> 
agg.getGroupByExpressions().equals(ImmutableList.of(k2)))
-                );
-    }
-
-    @Test
-    void testMixedLiteral() {
-        LogicalPlan aggregate = new LogicalPlanBuilder(scan)
-                .agg(ImmutableList.of(
-                             new StringLiteral("str"), k2,
-                             new IntegerLiteral(1),
-                             new IntegerLiteral(2),
-                             new IntegerLiteral(3),
-                             new Add(k1, k2)),
-                     ImmutableList.of(
-                             new Alias(new StringLiteral("str"), "str"),
-                             k2, k1, new Alias(new IntegerLiteral(1), 
"integer")))
-                .build();
-
-        PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
-                .applyTopDown(new EliminateGroupByConstant())
-                .applyBottomUp(new CheckAfterRewrite())
-                .matches(
-                        aggregate()
-                                .when(agg -> 
agg.getGroupByExpressions().equals(ImmutableList.of(k2, new Add(k1, k2))))
-                );
-    }
-
-    @Test
-    void testComplexGroupBy() {
-        LogicalPlan aggregate = new LogicalPlanBuilder(scan)
-                .agg(ImmutableList.of(
-                             new IntegerLiteral(1),
-                             new IntegerLiteral(2),
-                             new Add(k1, k2)),
-                     ImmutableList.of(
-                             new Alias(new Max(k1), "max"),
-                             new Alias(new Min(k2), "min"),
-                             new Alias(new Add(k1, k2), "add")))
-                .build();
-
-        PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
-                .applyTopDown(new EliminateGroupByConstant())
-                .applyBottomUp(new CheckAfterRewrite())
-                .matches(
-                        aggregate()
-                                .when(agg -> 
agg.getGroupByExpressions().equals(ImmutableList.of(new Add(k1, k2))))
-                );
-    }
-
-    @Test
-    void testOutOfRange() {
-        LogicalPlan aggregate = new LogicalPlanBuilder(scan)
-                .agg(ImmutableList.of(
-                             new StringLiteral("str"), k2,
-                             new IntegerLiteral(1),
-                             new IntegerLiteral(2),
-                             new IntegerLiteral(3),
-                             new IntegerLiteral(5),
-                             new Add(k1, k2)),
-                     ImmutableList.of(
-                                     new Alias(new StringLiteral("str"), 
"str"),
-                                     k2, k1, new Alias(new IntegerLiteral(1), 
"integer")))
-                .build();
-        PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
-                .applyTopDown(new EliminateGroupByConstant())
-                .applyBottomUp(new CheckAfterRewrite())
-                .matches(
-                        aggregate()
-                                .when(agg -> 
agg.getGroupByExpressions().equals(ImmutableList.of(k2, new Add(k1, k2))))
-                );
-    }
-}
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregateTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregateTest.java
index 3808fd18428..05494cd4af0 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregateTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregateTest.java
@@ -37,23 +37,35 @@ import 
org.apache.doris.nereids.util.MemoPatternMatchSupported;
 import org.apache.doris.nereids.util.MemoTestUtils;
 import org.apache.doris.nereids.util.PlanChecker;
 import org.apache.doris.nereids.util.PlanConstructor;
+import org.apache.doris.utframe.TestWithFeService;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Lists;
-import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.TestInstance;
 
 import java.util.List;
 
 @TestInstance(TestInstance.Lifecycle.PER_CLASS)
-public class NormalizeAggregateTest implements MemoPatternMatchSupported {
+public class NormalizeAggregateTest extends TestWithFeService implements 
MemoPatternMatchSupported {
     private LogicalPlan rStudent;
 
-    @BeforeAll
-    public final void beforeAll() {
+    @Override
+    protected void runBeforeAll() throws Exception {
         rStudent = new 
LogicalOlapScan(StatementScopeIdGenerator.newRelationId(), 
PlanConstructor.student,
                 ImmutableList.of());
+        // Table t1 (id, name) is the one queried by the SQL-based constant group-by elimination tests in this class.
+        createDatabase("test");
+        connectContext.setDatabase("default_cluster:test");
+        createTables(
+                "CREATE TABLE IF NOT EXISTS t1 (\n"
+                        + "    id int not null,\n"
+                        + "    name char\n"
+                        + ")\n"
+                        + "DUPLICATE KEY(id)\n"
+                        + "DISTRIBUTED BY HASH(id) BUCKETS 10\n"
+                        + "PROPERTIES (\"replication_num\" = \"1\")\n"
+        );
+        // NOTE(review): presumably disabled so the never-loaded t1 is not pruned away before the tests
+        // match a LogicalAggregate in the rewritten plan -- confirm.
+        
connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION");
     }
 
     /*-
@@ -190,4 +202,105 @@ public class NormalizeAggregateTest implements 
MemoPatternMatchSupported {
 
                 );
     }
+
+    // Tests for eliminating constant GROUP BY keys inside NormalizeAggregate.
+    // Every query groups by at least one constant key; after rewrite exactly one group-by expression
+    // must remain (the surviving column, or a placeholder key when every key is constant).
+    @Test
+    void testEliminateGroupByConst() {
+        assertSingleGroupByKeyAfterRewrite("select id ,1, 'abc' from t1 group by 1,2,3");
+    }
+
+    @Test
+    void useTinyIntEliminateGroupByConst() {
+        assertSingleGroupByKeyAfterRewrite("select 1, 'abc' from t1 group by 1,2");
+    }
+
+    @Test
+    void testMixedConstTypes() {
+        assertSingleGroupByKeyAfterRewrite("select id, 1, 'abc', true from t1 group by 1, 2, 3, 4");
+    }
+
+    @Test
+    void testNullConst() {
+        assertSingleGroupByKeyAfterRewrite("select id, NULL from t1 group by 1, 2");
+    }
+
+    @Test
+    void testTwoNullConst() {
+        assertSingleGroupByKeyAfterRewrite("select Null, NULL from t1 group by 1, 2");
+    }
+
+    @Test
+    void testExpressionConst() {
+        assertSingleGroupByKeyAfterRewrite("select id, 1+1, CONCAT('a','b') from t1 group by 1, 2, 3");
+    }
+
+    @Test
+    void testFunctionCallConst() {
+        assertSingleGroupByKeyAfterRewrite("select id, NOW(), PI() from t1 group by 1, 2, 3");
+    }
+
+    @Test
+    void testDifferentOrder() {
+        assertSingleGroupByKeyAfterRewrite("select 1, id, 'abc' from t1 group by 2, 1, 3");
+    }
+
+    @Test
+    void testDuplicateConst() {
+        assertSingleGroupByKeyAfterRewrite("select id, 1, 1 from t1 group by 1, 2, 3");
+    }
+
+    @Test
+    void testWithAggFunction() {
+        String sql = "select 'abc', 1, COUNT(*) from t1 group by 1, 2";
+        PlanChecker.from(connectContext)
+                .analyze(sql)
+                .rewrite()
+                .matches(logicalAggregate().when(agg -> agg.getGroupByExpressions().size() == 1
+                        && agg.getOutputExpressions().stream().anyMatch(e -> e.toString().contains("COUNT"))));
+    }
+
+    /** Analyzes and rewrites {@code sql}, then asserts the plan has an aggregate with exactly one group-by key. */
+    private void assertSingleGroupByKeyAfterRewrite(String sql) {
+        PlanChecker.from(connectContext)
+                .analyze(sql)
+                .rewrite()
+                .matches(logicalAggregate().when(agg -> agg.getGroupByExpressions().size() == 1));
+    }
 }
diff --git 
a/regression-test/data/nereids_rules_p0/eliminate_gby_key/eliminate_constant_gby_key.out
 
b/regression-test/data/nereids_rules_p0/eliminate_gby_key/eliminate_constant_gby_key.out
new file mode 100644
index 00000000000..f161b693bd1
Binary files /dev/null and 
b/regression-test/data/nereids_rules_p0/eliminate_gby_key/eliminate_constant_gby_key.out
 differ
diff --git 
a/regression-test/suites/nereids_rules_p0/eliminate_gby_key/eliminate_constant_gby_key.groovy
 
b/regression-test/suites/nereids_rules_p0/eliminate_gby_key/eliminate_constant_gby_key.groovy
new file mode 100644
index 00000000000..3158e2ceded
--- /dev/null
+++ 
b/regression-test/suites/nereids_rules_p0/eliminate_gby_key/eliminate_constant_gby_key.groovy
@@ -0,0 +1,172 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("eliminate_constant_gby_key") {
+    // Regression coverage for eliminating constant GROUP BY keys in NormalizeAggregate.
+    // Results are presumably checked against eliminate_constant_gby_key.out, so changing any query
+    // text below requires regenerating that file.
+    sql """DROP TABLE IF EXISTS t1;"""
+    sql """CREATE TABLE t1 (
+        c1 INT,
+        c2 VARCHAR(50),
+        c3 DECIMAL(10,2),
+        c4 DATETIME,
+        c5 BOOLEAN
+    ) distributed by hash(c1) properties("replication_num"="1");"""
+
+    sql """INSERT INTO t1 (c1, c2, c3, c4, c5) VALUES 
+    (1, 'Apple', 10.50, '2023-01-01 10:00:00', true),
+    (2, 'Banana', 20.75, '2023-01-02 11:30:00', false),
+    (3, 'Cherry', 15.25, '2023-01-03 09:15:00', true),
+    (4, 'Date', 30.00, '2023-01-04 14:45:00', false),
+    (5, 'Elderberry', 12.99, '2023-01-05 16:20:00', true),
+    (0, 'Fig', 5.50, '2023-01-06 08:00:00', false),
+    (-1, 'Grape', 8.25, '2023-01-07 12:30:00', true),
+    (NULL, 'Honeydew', NULL, NULL, NULL),
+    (10, 'Iceberg', 18.40, '2023-01-08 13:10:00', false),
+    (100, 'Jackfruit', 42.99, '2023-01-09 17:55:00', true);
+    """
+
+    // Candidate values for funA, both expected to fold to constants on FE. In the CASE expression the
+    // first WHEN compares two string literals ('2024-01-08' < '2024-02-18'), which is always true, so
+    // the expression folds even though later branches reference c4.
+    def funAList = [
+        "TIMESTAMPDIFF(YEAR, NOW(), NOW())",
+        """(TO_DATE(CASE
+        WHEN ('2024-01-08' < '2024-02-18') THEN '2023-12-19'
+        WHEN (c4 < '2024-01-01') THEN '2026-02-18'
+        ELSE DATE_ADD(c4, INTERVAL 365 DAY) END))"""
+    ]
+
+    // Query templates; each is instantiated once per funA entry by the loop below.
+    def testCases = [
+        [desc: "select funA, c1, funA+c1 group by funA, c1",
+         sql: { funA -> """
+        SELECT 
+            ${funA} AS funA, 
+            c1, 
+            ${funA} + c1 AS 'funA+c1' 
+        FROM t1 
+        GROUP BY ${funA}, c1
+        ORDER BY 1, 2, 3
+        """ }],
+
+        [desc: "select funA, c1, funA+c1 group by funA, c1, funA+c1",
+         sql: { funA -> """
+        SELECT 
+            ${funA} AS funA, 
+            c1, 
+            ${funA} + c1 AS 'funA+c1' 
+        FROM t1 
+        GROUP BY ${funA}, c1, ${funA} + c1
+        ORDER BY 1, 2, 3
+        """ }],
+
+        [desc: "select count(distinct funA), funA, c1 group by funA,c1",
+         sql: { funA -> """
+        SELECT 
+            COUNT(DISTINCT ${funA}) AS 'count(distinct funA)', 
+            ${funA} AS funA, 
+            c1 
+        FROM t1 
+        GROUP BY ${funA}, c1
+        ORDER BY 1, 2, 3
+        """ }],
+
+        [desc: "select count(funA), funA, c1 group by funA, c1",
+         sql: { funA -> """
+        SELECT 
+            COUNT(${funA}) AS 'count(funA)', 
+            ${funA} AS funA, 
+            c1 
+        FROM t1 
+        GROUP BY ${funA}, c1
+        ORDER BY 1, 2, 3
+        """ }],
+
+        [desc: "select COUNT(distinct funA+1), funA, c1 group by funA,c1",
+         sql: { funA -> """
+        SELECT 
+            COUNT(DISTINCT ${funA} + 1) AS 'count(distinct funA+1)', 
+            ${funA} AS funA, 
+            c1 
+        FROM t1 
+        GROUP BY ${funA}, c1
+        ORDER BY 1, 2, 3
+        """ }],
+
+        [desc: "select max(funA+1), funA, c1 group by funA, c1",
+         sql: { funA -> """
+        SELECT 
+            MAX(${funA} + 1) AS 'max(funA+1)', 
+            ${funA} AS funA, 
+            c1 
+        FROM t1 
+        GROUP BY ${funA}, c1
+        ORDER BY 1, 2, 3
+        """ }],
+
+        [desc: "select max(funA+c1), funA, c2 group by funA, c2",
+         sql: { funA -> """
+        SELECT
+            MAX(${funA} + c1) AS 'max(funA+c1)',
+            ${funA} AS funA,
+            c2
+        FROM t1
+        GROUP BY ${funA}, c2
+        ORDER BY 1, 2, 3
+        """ }]
+    ]
+
+    // Runs every (funA, template) combination; tags are test_1, test_2, ... in iteration order.
+    def idx = 1
+    funAList.each { funA ->
+        testCases.each { testCase ->
+            quickTest("test_${idx}", testCase.sql(funA))
+            idx++
+        }
+    }
+
+    // NOTE(review): the alias reads 'max(distinct funA)' but the aggregate is count(DISTINCT ...).
+    qt_gby_key_is_constant_expr_not_literal """
+    SELECT 
+        count(DISTINCT  from_unixtime(1742860744.003242)) AS 'max(distinct 
funA)', 
+        from_unixtime(1742860744.003242) AS funA, 
+        c1 
+    FROM t1 
+    GROUP BY from_unixtime(1742860744.003242), c1
+    order by 1,2,3
+    """
+
+    // Every GROUP BY key below is meant to fold to a constant, exercising the synthetic-key path.
+    // NOTE(review): there is no comma after the TO_DATE(...) expression, so `c1` is parsed as its alias.
+    // The select list therefore has 3 columns, yet ORDER BY lists ordinal 4 -- confirm this is intended.
+    qt_test_gby_key_is_all_constant """
+    SELECT 
+        count(DISTINCT  from_unixtime(1742860744.003242)) AS 'max(distinct 
funA)', 
+        from_unixtime(1742860744.003242) AS funA, 
+        (TO_DATE(CASE
+        WHEN ('2024-01-08' < '2024-02-18') THEN '2023-12-19'
+        WHEN (c4 < '2024-01-01') THEN '2026-02-18'
+        ELSE DATE_ADD(c4, INTERVAL 365 DAY) END))
+        c1 
+    FROM t1 
+    GROUP BY from_unixtime(1742860744.003242), (TO_DATE(CASE
+        WHEN ('2024-01-08' < '2024-02-18') THEN '2023-12-19'
+        WHEN (c4 < '2024-01-01') THEN '2026-02-18'
+        ELSE DATE_ADD(c4, INTERVAL 365 DAY) END)), 'abc'
+    order by 1,2,3,4
+    """
+
+    // NOTE(review): the missing comma makes `c1` an alias of the second from_unixtime(...).
+    // That leaves 2 select columns while ORDER BY lists ordinals 1-4 -- confirm this is intended.
+    qt_duplicate_gby_key """
+    SELECT 
+        from_unixtime(1742860744.003242),
+         from_unixtime(1742860744.003242)
+        c1 
+    FROM t1 
+    GROUP BY from_unixtime(1742860744.003242), 
from_unixtime(1742860744.003242),'abc',c1
+    order by 1,2,3,4
+    """
+}
\ No newline at end of file


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to