This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-2.1 by this push:
new ebe302cb7e6 branch-2.1: [fix](nereids) do eliminate constant group by key in normalizeagg #49589 (#50212)
ebe302cb7e6 is described below
commit ebe302cb7e690199de814fb2634aacde6d70a65e
Author: feiniaofeiafei <[email protected]>
AuthorDate: Thu May 8 18:52:40 2025 +0800
branch-2.1: [fix](nereids) do eliminate constant group by key in normalizeagg #49589 (#50212)
Cherry-picked from https://github.com/apache/doris/pull/49589
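In short, constant group-by key elimination now happens inside NormalizeAggregate (using FoldConstantRuleOnFE) rather than as the separate EliminateGroupByConstant rule in the analyzer and rewriter pipelines. A minimal sketch of the intended rewrite, assuming a hypothetical table t(c1 INT) that is not part of this patch:

    -- positions 2 and 3 fold to constants, so only c1 survives as a real group-by key
    SELECT c1, 1, 'abc' FROM t GROUP BY 1, 2, 3;
    -- is planned as if the constant keys were dropped:
    SELECT c1, 1, 'abc' FROM t GROUP BY c1;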
---
.../doris/nereids/jobs/executor/Analyzer.java | 3 -
.../doris/nereids/jobs/executor/Rewriter.java | 2 -
.../nereids/rules/analysis/NormalizeAggregate.java | 108 ++++++++++++-
.../analysis/EliminateGroupByConstantTest.java | 165 --------------------
.../rules/analysis/NormalizeAggregateTest.java | 121 ++++++++++++++-
.../eliminate_constant_gby_key.out | Bin 0 -> 3009 bytes
.../eliminate_constant_gby_key.groovy | 172 +++++++++++++++++++++
7 files changed, 392 insertions(+), 179 deletions(-)
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java
index 5b8e52c0289..24129a10834 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java
@@ -30,7 +30,6 @@ import org.apache.doris.nereids.rules.analysis.CheckPolicy;
import org.apache.doris.nereids.rules.analysis.CollectJoinConstraint;
import org.apache.doris.nereids.rules.analysis.CollectSubQueryAlias;
import org.apache.doris.nereids.rules.analysis.EliminateDistinctConstant;
-import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant;
import org.apache.doris.nereids.rules.analysis.EliminateLogicalSelectHint;
import org.apache.doris.nereids.rules.analysis.FillUpMissingSlots;
import org.apache.doris.nereids.rules.analysis.HavingToFilter;
@@ -136,8 +135,6 @@ public class Analyzer extends AbstractBatchJobExecutor {
// select SUM(lo_tax) FROM lineorder group by 1;
// errCode = 2, detailMessage = GROUP BY expression must not contain aggregate functions: sum(lo_tax)
bottomUp(new CheckAnalysis()),
- topDown(new EliminateGroupByConstant()),
-
topDown(new SimplifyAggGroupBy()),
topDown(new NormalizeAggregate()),
topDown(new HavingToFilter()),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
index 122d6a6cf63..4c712e919ba 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
@@ -25,7 +25,6 @@ import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.analysis.AdjustAggregateNullableForEmptySet;
import org.apache.doris.nereids.rules.analysis.AvgDistinctToSumDivCount;
import org.apache.doris.nereids.rules.analysis.CheckAfterRewrite;
-import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant;
import org.apache.doris.nereids.rules.analysis.LogicalSubQueryAliasToLogicalProject;
import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
import org.apache.doris.nereids.rules.expression.CheckLegalityAfterRewrite;
@@ -158,7 +157,6 @@ public class Rewriter extends AbstractBatchJobExecutor {
topDown(
new EliminateOrderByConstant(),
new EliminateSortUnderSubqueryOrView(),
- new EliminateGroupByConstant(),
// MergeProjects depends on this rule
new LogicalSubQueryAliasToLogicalProject(),
// TODO: we should do expression normalization after plan normalization
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
index e5ebee120a3..4a2e226caae 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
@@ -17,9 +17,12 @@
package org.apache.doris.nereids.rules.analysis;
+import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
+import org.apache.doris.nereids.rules.expression.rules.FoldConstantRuleOnFE;
import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot;
import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Alias;
@@ -35,6 +38,7 @@ import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinction;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
@@ -50,6 +54,7 @@ import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import java.util.ArrayList;
+import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
@@ -111,14 +116,16 @@ public class NormalizeAggregate implements RewriteRuleFactory, NormalizeToSlot {
public List<Rule> buildRules() {
return ImmutableList.of(
logicalHaving(logicalAggregate().whenNot(LogicalAggregate::isNormalized))
- .then(having -> normalizeAgg(having.child(), Optional.of(having)))
+ .thenApply(ctx -> normalizeAgg(ctx.root.child(), Optional.of(ctx.root), ctx.cascadesContext))
.toRule(RuleType.NORMALIZE_AGGREGATE),
logicalAggregate().whenNot(LogicalAggregate::isNormalized)
- .then(aggregate -> normalizeAgg(aggregate, Optional.empty()))
+ .thenApply(ctx -> normalizeAgg(ctx.root, Optional.empty(), ctx.cascadesContext))
.toRule(RuleType.NORMALIZE_AGGREGATE));
}
- private LogicalPlan normalizeAgg(LogicalAggregate<Plan> aggregate, Optional<LogicalHaving<?>> having) {
+ @SuppressWarnings("checkstyle:UnusedLocalVariable")
+ private LogicalPlan normalizeAgg(LogicalAggregate<Plan> aggregate, Optional<LogicalHaving<?>> having,
+ CascadesContext ctx) {
// The LogicalAggregate node may contain window agg functions and usual agg functions
// we call window agg functions as window-agg and usual agg functions as trivial-agg for short
// This rule simplify LogicalAggregate node by:
@@ -279,8 +286,10 @@ public class NormalizeAggregate implements RewriteRuleFactory, NormalizeToSlot {
List<NamedExpression> upperProjects = normalizeOutput(aggregateOutput,
groupByExprContext, argsOfAggFuncNeedPushDownContext, normalizedAggFuncsToSlotContext);
- // create a parent project node
- LogicalProject<Plan> project = new LogicalProject<>(upperProjects, newAggregate);
+ ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(ctx);
+ LogicalProject<Plan> project = eliminateGroupByConstant(groupByExprContext, rewriteContext,
+ normalizedGroupExprs, normalizedAggOutput, bottomProjects, aggregate, upperProjects, newAggregate);
+
// verify project used slots are all coming from agg's output
List<Slot> slots = collectAllUsedSlots(upperProjects);
if (!slots.isEmpty()) {
@@ -389,4 +398,93 @@ public class NormalizeAggregate implements RewriteRuleFactory, NormalizeToSlot {
return expr;
}
}
+
+ private LogicalProject<Plan> eliminateGroupByConstant(NormalizeToSlotContext groupByExprContext,
+ ExpressionRewriteContext rewriteContext, List<Expression> normalizedGroupExprs,
+ List<NamedExpression> normalizedAggOutput, Set<NamedExpression> bottomProjects,
+ LogicalAggregate<Plan> aggregate, List<NamedExpression> upperProjects, LogicalAggregate<?> newAggregate) {
+ // 1. Find the expressions in group by that can be folded into constants and build a map(slot, literal)
+ Map<Expression, NormalizeToSlotTriplet> replaceMap = groupByExprContext.getNormalizeToSlotMap();
+ if (replaceMap.isEmpty()) {
+ return new LogicalProject<>(upperProjects, newAggregate);
+ }
+ Map<Slot, Expression> slotToLiteral = new HashMap<>();
+ for (Map.Entry<Expression, NormalizeToSlotTriplet> entry : replaceMap.entrySet()) {
+ Expression foldExpression = FoldConstantRuleOnFE.evaluate(entry.getKey(), rewriteContext);
+ if (foldExpression.isConstant()) {
+ slotToLiteral.put(entry.getValue().remainExpr, foldExpression);
+ }
+ }
+ if (slotToLiteral.isEmpty()) {
+ return new LogicalProject<>(upperProjects, newAggregate);
+ }
+ // 2. Regenerate a group by list without constant key
+ List<Expression> newNormalizedGroupExprs = new ArrayList<>();
+ for (Expression normalizedGroupExpr : normalizedGroupExprs) {
+ if (!slotToLiteral.containsKey((Slot) normalizedGroupExpr)) {
+ newNormalizedGroupExprs.add(normalizedGroupExpr);
+ }
+ }
+ if (newNormalizedGroupExprs.size() == normalizedGroupExprs.size()) {
+ return new LogicalProject<>(upperProjects, newAggregate);
+ }
+ if (newNormalizedGroupExprs.isEmpty()) {
+ Alias tinyInt = new Alias(new TinyIntLiteral((byte) 1));
+ bottomProjects = new HashSet<>(bottomProjects);
+ bottomProjects.add(tinyInt);
+ normalizedAggOutput = new ArrayList<>(normalizedAggOutput);
+ Slot tinyIntSlot = tinyInt.toSlot();
+ normalizedAggOutput.add(tinyIntSlot);
+ newNormalizedGroupExprs.add(tinyIntSlot);
+ }
+ // 3. Replace the agg output expression and delete the constant group by key in the output
+ ImmutableList.Builder<NamedExpression> nonConstAggOutput = ImmutableList.builder();
+ for (NamedExpression ne : normalizedAggOutput) {
+ if (ne instanceof Alias) {
+ nonConstAggOutput.add(ExpressionUtils.replaceNameExpression(ne, slotToLiteral));
+ continue;
+ } else if (ne instanceof Slot) {
+ if (!slotToLiteral.containsKey(ne)) {
+ nonConstAggOutput.add(ne);
+ }
+ continue;
+ }
+ nonConstAggOutput.add(ne);
+ }
+
+ // 4. The constant expression calculation in bottom projects needs to be deleted
+ // and put into upperProjects for calculation
+ Plan bottomPlan;
+ if (!bottomProjects.isEmpty()) {
+ ImmutableList.Builder<NamedExpression> builder = ImmutableList.builder();
+ for (NamedExpression bottomProject : bottomProjects) {
+ if (!slotToLiteral.containsKey(bottomProject.toSlot())) {
+ builder.add(bottomProject);
+ }
+ }
+ bottomPlan = new LogicalProject<>(builder.build(), aggregate.child());
+ } else {
+ bottomPlan = aggregate.child();
+ }
+ LogicalAggregate<Plan> newAggAfterEliminate = aggregate.withNormalized(newNormalizedGroupExprs,
+ nonConstAggOutput.build(), bottomPlan);
+ // 5. This upperProjects needs to add the constant key that was deleted in the group by key
+ // and change the reference to the constant key to a constant expression
+ ImmutableList.Builder<NamedExpression> newUpperProjects = ImmutableList.builder();
+ for (NamedExpression upperProject : upperProjects) {
+ if (upperProject instanceof Alias) {
+ newUpperProjects.add(ExpressionUtils.replaceNameExpression(upperProject, slotToLiteral));
+ continue;
+ } else if (upperProject instanceof Slot) {
+ if (slotToLiteral.containsKey(upperProject)) {
+ Alias newLiteral = new Alias(upperProject.getExprId(), slotToLiteral.get(upperProject),
+ upperProject.getName());
+ newUpperProjects.add(newLiteral);
+ continue;
+ }
+ }
+ newUpperProjects.add(upperProject);
+ }
+ return new LogicalProject<>(newUpperProjects.build(), newAggAfterEliminate);
+ }
}
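A note on the all-constant case in the method above: when the regenerated group-by list (step 2) becomes empty, the rewrite keeps a TinyIntLiteral(1) placeholder key instead of dropping GROUP BY entirely, because a grouped aggregate over an empty input returns zero rows while a scalar aggregate returns one. A sketch of that semantic difference, using a hypothetical empty table t_empty that is not part of this patch:

    -- all group-by keys are constant; keeping a placeholder key preserves grouped semantics
    SELECT 1, 'abc', COUNT(*) FROM t_empty GROUP BY 1, 2;   -- returns 0 rows
    -- dropping GROUP BY would turn this into a scalar aggregate
    SELECT 1, 'abc', COUNT(*) FROM t_empty;                 -- returns 1 row with COUNT(*) = 0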
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstantTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstantTest.java
deleted file mode 100644
index c35b983911c..00000000000
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstantTest.java
+++ /dev/null
@@ -1,165 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.doris.nereids.rules.analysis;
-
-import org.apache.doris.catalog.AggregateType;
-import org.apache.doris.catalog.Column;
-import org.apache.doris.catalog.KeysType;
-import org.apache.doris.catalog.OlapTable;
-import org.apache.doris.catalog.PartitionInfo;
-import org.apache.doris.catalog.Type;
-import org.apache.doris.nereids.trees.expressions.Add;
-import org.apache.doris.nereids.trees.expressions.Alias;
-import org.apache.doris.nereids.trees.expressions.Slot;
-import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
-import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
-import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
-import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
-import org.apache.doris.nereids.trees.plans.RelationId;
-import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
-import org.apache.doris.nereids.util.LogicalPlanBuilder;
-import org.apache.doris.nereids.util.MemoPatternMatchSupported;
-import org.apache.doris.nereids.util.MemoTestUtils;
-import org.apache.doris.nereids.util.PlanChecker;
-import org.apache.doris.thrift.TStorageType;
-
-import com.google.common.collect.ImmutableList;
-import org.junit.jupiter.api.Test;
-
-/** Tests for {@link EliminateGroupByConstant}. */
-class EliminateGroupByConstantTest implements MemoPatternMatchSupported {
- private static final OlapTable table = new OlapTable(0L, "student",
- ImmutableList.of(new Column("k1", Type.INT, true, AggregateType.NONE, "0", ""),
- new Column("k2", Type.INT, false, AggregateType.NONE, "0", ""),
- new Column("k3", Type.INT, true, AggregateType.NONE, "", "")),
- KeysType.PRIMARY_KEYS, new PartitionInfo(), null);
-
- static {
- table.setIndexMeta(-1,
- "t1",
- table.getFullSchema(),
- 0, 0, (short) 0,
- TStorageType.COLUMN,
- KeysType.PRIMARY_KEYS);
- }
-
- private static final LogicalOlapScan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), table);
- private static final Slot k1 = scan.getOutput().get(0);
- private static final Slot k2 = scan.getOutput().get(1);
-
- @Test
- void testIntegerLiteral() {
- LogicalPlan aggregate = new LogicalPlanBuilder(scan)
- .agg(ImmutableList.of(new IntegerLiteral(1), k2),
- ImmutableList.of(k1, k2))
- .build();
-
- PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
- .applyTopDown(new EliminateGroupByConstant())
- .applyBottomUp(new CheckAfterRewrite())
- .matches(
- aggregate().when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of(k2)))
- );
- }
-
- @Test
- void testOtherLiteral() {
- LogicalPlan aggregate = new LogicalPlanBuilder(scan)
- .agg(ImmutableList.of(
- new StringLiteral("str"), k2),
- ImmutableList.of(
- new Alias(new StringLiteral("str"), "str"), k1, k2))
- .build();
-
- PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
- .applyTopDown(new EliminateGroupByConstant())
- .applyBottomUp(new CheckAfterRewrite())
- .matches(
- aggregate().when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of(k2)))
- );
- }
-
- @Test
- void testMixedLiteral() {
- LogicalPlan aggregate = new LogicalPlanBuilder(scan)
- .agg(ImmutableList.of(
- new StringLiteral("str"), k2,
- new IntegerLiteral(1),
- new IntegerLiteral(2),
- new IntegerLiteral(3),
- new Add(k1, k2)),
- ImmutableList.of(
- new Alias(new StringLiteral("str"), "str"),
- k2, k1, new Alias(new IntegerLiteral(1), "integer")))
- .build();
-
- PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
- .applyTopDown(new EliminateGroupByConstant())
- .applyBottomUp(new CheckAfterRewrite())
- .matches(
- aggregate()
- .when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of(k2, new Add(k1, k2))))
- );
- }
-
- @Test
- void testComplexGroupBy() {
- LogicalPlan aggregate = new LogicalPlanBuilder(scan)
- .agg(ImmutableList.of(
- new IntegerLiteral(1),
- new IntegerLiteral(2),
- new Add(k1, k2)),
- ImmutableList.of(
- new Alias(new Max(k1), "max"),
- new Alias(new Min(k2), "min"),
- new Alias(new Add(k1, k2), "add")))
- .build();
-
- PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
- .applyTopDown(new EliminateGroupByConstant())
- .applyBottomUp(new CheckAfterRewrite())
- .matches(
- aggregate()
- .when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of(new Add(k1, k2))))
- );
- }
-
- @Test
- void testOutOfRange() {
- LogicalPlan aggregate = new LogicalPlanBuilder(scan)
- .agg(ImmutableList.of(
- new StringLiteral("str"), k2,
- new IntegerLiteral(1),
- new IntegerLiteral(2),
- new IntegerLiteral(3),
- new IntegerLiteral(5),
- new Add(k1, k2)),
- ImmutableList.of(
- new Alias(new StringLiteral("str"), "str"),
- k2, k1, new Alias(new IntegerLiteral(1), "integer")))
- .build();
- PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
- .applyTopDown(new EliminateGroupByConstant())
- .applyBottomUp(new CheckAfterRewrite())
- .matches(
- aggregate()
- .when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of(k2, new Add(k1, k2))))
- );
- }
-}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregateTest.java
index 3808fd18428..05494cd4af0 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregateTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregateTest.java
@@ -37,23 +37,35 @@ import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
+import org.apache.doris.utframe.TestWithFeService;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
-import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import java.util.List;
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
-public class NormalizeAggregateTest implements MemoPatternMatchSupported {
+public class NormalizeAggregateTest extends TestWithFeService implements MemoPatternMatchSupported {
private LogicalPlan rStudent;
- @BeforeAll
- public final void beforeAll() {
+ @Override
+ protected void runBeforeAll() throws Exception {
rStudent = new LogicalOlapScan(StatementScopeIdGenerator.newRelationId(), PlanConstructor.student,
ImmutableList.of());
+ createDatabase("test");
+ connectContext.setDatabase("default_cluster:test");
+ createTables(
+ "CREATE TABLE IF NOT EXISTS t1 (\n"
+ + " id int not null,\n"
+ + " name char\n"
+ + ")\n"
+ + "DUPLICATE KEY(id)\n"
+ + "DISTRIBUTED BY HASH(id) BUCKETS 10\n"
+ + "PROPERTIES (\"replication_num\" = \"1\")\n"
+ );
+ connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION");
}
/*-
@@ -190,4 +202,105 @@ public class NormalizeAggregateTest implements MemoPatternMatchSupported {
);
}
+
+ // add test for agg eliminate const
+ @Test
+ void testEliminateGroupByConst() {
+ String sql = "select id ,1, 'abc' from t1 group by 1,2,3";
+ PlanChecker.from(connectContext)
+ .analyze(sql)
+ .rewrite()
+ .matches(logicalAggregate().when(aggregate -> aggregate.getGroupByExpressions().size() == 1));
+ }
+
+ @Test
+ void useTinyIntEliminateGroupByConst() {
+ String sql = "select 1, 'abc' from t1 group by 1,2";
+ PlanChecker.from(connectContext)
+ .analyze(sql)
+ .rewrite()
+ .matches(logicalAggregate().when(agg ->
+ agg.getGroupByExpressions().size() == 1));
+ }
+
+ @Test
+ void testMixedConstTypes() {
+ String sql = "select id, 1, 'abc', true from t1 group by 1, 2, 3, 4";
+ PlanChecker.from(connectContext)
+ .analyze(sql)
+ .rewrite()
+ .matches(logicalAggregate().when(agg ->
+ agg.getGroupByExpressions().size() == 1));
+ }
+
+ @Test
+ void testNullConst() {
+ String sql = "select id, NULL from t1 group by 1, 2";
+ PlanChecker.from(connectContext)
+ .analyze(sql)
+ .rewrite()
+ .matches(logicalAggregate().when(agg ->
+ agg.getGroupByExpressions().size() == 1));
+ }
+
+ @Test
+ void testTwoNullConst() {
+ String sql = "select Null, NULL from t1 group by 1, 2";
+ PlanChecker.from(connectContext)
+ .analyze(sql)
+ .rewrite()
+ .matches(logicalAggregate().when(agg ->
+ agg.getGroupByExpressions().size() == 1));
+ }
+
+ @Test
+ void testExpressionConst() {
+ String sql = "select id, 1+1, CONCAT('a','b') from t1 group by 1, 2,
3";
+ PlanChecker.from(connectContext)
+ .analyze(sql)
+ .rewrite()
+ .matches(logicalAggregate().when(agg ->
+ agg.getGroupByExpressions().size() == 1));
+ }
+
+ @Test
+ void testFunctionCallConst() {
+ String sql = "select id, NOW(), PI() from t1 group by 1, 2, 3";
+ PlanChecker.from(connectContext)
+ .analyze(sql)
+ .rewrite()
+ .matches(logicalAggregate().when(agg ->
+ agg.getGroupByExpressions().size() == 1));
+ }
+
+ @Test
+ void testDifferentOrder() {
+ String sql = "select 1, id, 'abc' from t1 group by 2, 1, 3";
+ PlanChecker.from(connectContext)
+ .analyze(sql)
+ .rewrite()
+ .matches(logicalAggregate().when(agg ->
+ agg.getGroupByExpressions().size() == 1));
+ }
+
+ @Test
+ void testDuplicateConst() {
+ String sql = "select id, 1, 1 from t1 group by 1, 2, 3";
+ PlanChecker.from(connectContext)
+ .analyze(sql)
+ .rewrite()
+ .matches(logicalAggregate().when(agg ->
+ agg.getGroupByExpressions().size() == 1));
+ }
+
+ @Test
+ void testWithAggFunction() {
+ String sql = "select 'abc', 1, COUNT(*) from t1 group by 1, 2";
+ PlanChecker.from(connectContext)
+ .analyze(sql)
+ .rewrite()
+ .matches(logicalAggregate().when(agg ->
+ agg.getGroupByExpressions().size() == 1
+ &&
agg.getOutputExpressions().stream().anyMatch(e ->
e.toString().contains("COUNT"))));
+ }
}
diff --git a/regression-test/data/nereids_rules_p0/eliminate_gby_key/eliminate_constant_gby_key.out b/regression-test/data/nereids_rules_p0/eliminate_gby_key/eliminate_constant_gby_key.out
new file mode 100644
index 00000000000..f161b693bd1
Binary files /dev/null and b/regression-test/data/nereids_rules_p0/eliminate_gby_key/eliminate_constant_gby_key.out differ
diff --git a/regression-test/suites/nereids_rules_p0/eliminate_gby_key/eliminate_constant_gby_key.groovy b/regression-test/suites/nereids_rules_p0/eliminate_gby_key/eliminate_constant_gby_key.groovy
new file mode 100644
index 00000000000..3158e2ceded
--- /dev/null
+++ b/regression-test/suites/nereids_rules_p0/eliminate_gby_key/eliminate_constant_gby_key.groovy
@@ -0,0 +1,172 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("eliminate_constant_gby_key") {
+ sql """DROP TABLE IF EXISTS t1;"""
+ sql """CREATE TABLE t1 (
+ c1 INT,
+ c2 VARCHAR(50),
+ c3 DECIMAL(10,2),
+ c4 DATETIME,
+ c5 BOOLEAN
+ ) distributed by hash(c1) properties("replication_num"="1");"""
+
+ sql """INSERT INTO t1 (c1, c2, c3, c4, c5) VALUES
+ (1, 'Apple', 10.50, '2023-01-01 10:00:00', true),
+ (2, 'Banana', 20.75, '2023-01-02 11:30:00', false),
+ (3, 'Cherry', 15.25, '2023-01-03 09:15:00', true),
+ (4, 'Date', 30.00, '2023-01-04 14:45:00', false),
+ (5, 'Elderberry', 12.99, '2023-01-05 16:20:00', true),
+ (0, 'Fig', 5.50, '2023-01-06 08:00:00', false),
+ (-1, 'Grape', 8.25, '2023-01-07 12:30:00', true),
+ (NULL, 'Honeydew', NULL, NULL, NULL),
+ (10, 'Iceberg', 18.40, '2023-01-08 13:10:00', false),
+ (100, 'Jackfruit', 42.99, '2023-01-09 17:55:00', true);
+ """
+
+ def funAList = [
+ "TIMESTAMPDIFF(YEAR, NOW(), NOW())",
+ """(TO_DATE(CASE
+ WHEN ('2024-01-08' < '2024-02-18') THEN '2023-12-19'
+ WHEN (c4 < '2024-01-01') THEN '2026-02-18'
+ ELSE DATE_ADD(c4, INTERVAL 365 DAY) END))"""
+ ]
+
+ def testCases = [
+ [desc: "select funA, c1, funA+c1 group by funA, c1",
+ sql: { funA -> """
+ SELECT
+ ${funA} AS funA,
+ c1,
+ ${funA} + c1 AS 'funA+c1'
+ FROM t1
+ GROUP BY ${funA}, c1
+ ORDER BY 1, 2, 3
+ """ }],
+
+ [desc: "select funA, c1, funA+c1 group by funA, c1, funA+c1",
+ sql: { funA -> """
+ SELECT
+ ${funA} AS funA,
+ c1,
+ ${funA} + c1 AS 'funA+c1'
+ FROM t1
+ GROUP BY ${funA}, c1, ${funA} + c1
+ ORDER BY 1, 2, 3
+ """ }],
+
+ [desc: "select count(distinct funA), funA, c1 group by funA,c1",
+ sql: { funA -> """
+ SELECT
+ COUNT(DISTINCT ${funA}) AS 'count(distinct funA)',
+ ${funA} AS funA,
+ c1
+ FROM t1
+ GROUP BY ${funA}, c1
+ ORDER BY 1, 2, 3
+ """ }],
+
+ [desc: "select count(funA), funA, c1 group by funA, c1",
+ sql: { funA -> """
+ SELECT
+ COUNT(${funA}) AS 'count(funA)',
+ ${funA} AS funA,
+ c1
+ FROM t1
+ GROUP BY ${funA}, c1
+ ORDER BY 1, 2, 3
+ """ }],
+
+ [desc: "select COUNT(distinct funA+1), funA, c1 group by funA,c1",
+ sql: { funA -> """
+ SELECT
+ COUNT(DISTINCT ${funA} + 1) AS 'count(distinct funA+1)',
+ ${funA} AS funA,
+ c1
+ FROM t1
+ GROUP BY ${funA}, c1
+ ORDER BY 1, 2, 3
+ """ }],
+
+ [desc: "select max(funA+1), funA, c1 group by funA, c1",
+ sql: { funA -> """
+ SELECT
+ MAX(${funA} + 1) AS 'max(funA+1)',
+ ${funA} AS funA,
+ c1
+ FROM t1
+ GROUP BY ${funA}, c1
+ ORDER BY 1, 2, 3
+ """ }],
+
+ [desc: "select max(funA+c1), funA, c2 group by funA, c2",
+ sql: { funA -> """
+ SELECT
+ MAX(${funA} + c1) AS 'max(funA+c1)',
+ ${funA} AS funA,
+ c2
+ FROM t1
+ GROUP BY ${funA}, c2
+ ORDER BY 1, 2, 3
+ """ }]
+ ]
+
+ def idx = 1
+ funAList.each { funA ->
+ testCases.each { testCase ->
+ quickTest("test_${idx}", testCase.sql(funA))
+ idx++
+ }
+ }
+
+ qt_gby_key_is_constant_expr_not_literal """
+ SELECT
+ count(DISTINCT from_unixtime(1742860744.003242)) AS 'max(distinct funA)',
+ from_unixtime(1742860744.003242) AS funA,
+ c1
+ FROM t1
+ GROUP BY from_unixtime(1742860744.003242), c1
+ order by 1,2,3
+ """
+
+ qt_test_gby_key_is_all_constant """
+ SELECT
+ count(DISTINCT from_unixtime(1742860744.003242)) AS 'max(distinct funA)',
+ from_unixtime(1742860744.003242) AS funA,
+ (TO_DATE(CASE
+ WHEN ('2024-01-08' < '2024-02-18') THEN '2023-12-19'
+ WHEN (c4 < '2024-01-01') THEN '2026-02-18'
+ ELSE DATE_ADD(c4, INTERVAL 365 DAY) END))
+ c1
+ FROM t1
+ GROUP BY from_unixtime(1742860744.003242), (TO_DATE(CASE
+ WHEN ('2024-01-08' < '2024-02-18') THEN '2023-12-19'
+ WHEN (c4 < '2024-01-01') THEN '2026-02-18'
+ ELSE DATE_ADD(c4, INTERVAL 365 DAY) END)), 'abc'
+ order by 1,2,3,4
+ """
+
+ qt_duplicate_gby_key """
+ SELECT
+ from_unixtime(1742860744.003242),
+ from_unixtime(1742860744.003242)
+ c1
+ FROM t1
+ GROUP BY from_unixtime(1742860744.003242), from_unixtime(1742860744.003242),'abc',c1
+ order by 1,2,3,4
+ """
+}
\ No newline at end of file
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]