This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch branch-2.1 in repository https://gitbox.apache.org/repos/asf/doris.git
commit e413dbec91dfad3610b8e3cae8ae30f6317b9276 Author: starocean999 <[email protected]> AuthorDate: Fri Feb 2 20:37:54 2024 +0800 [fix](nereids) need to substitute the agg function using the agg node's output if it's in an order by key (#30704) --- .../nereids/rules/analysis/FillUpMissingSlots.java | 7 ++++++- .../rules/analysis/FillUpMissingSlotsTest.java | 23 ++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java index c8efc1a8917..5688b9d48b3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java @@ -183,10 +183,14 @@ public class FillUpMissingSlots implements AnalysisRuleFactory { private final List<Expression> groupByExpressions; private final Map<Expression, Slot> substitution = Maps.newHashMap(); private final List<NamedExpression> newOutputSlots = Lists.newArrayList(); + private final Map<Slot, Expression> outputSubstitutionMap; Resolver(Aggregate aggregate) { outputExpressions = aggregate.getOutputExpressions(); groupByExpressions = aggregate.getGroupByExpressions(); + outputSubstitutionMap = outputExpressions.stream().filter(Alias.class::isInstance) + .collect(Collectors.toMap(alias -> alias.toSlot(), alias -> alias.child(0), + (k1, k2) -> k1)); } public void resolve(Expression expression) { @@ -273,7 +277,8 @@ public class FillUpMissingSlots implements AnalysisRuleFactory { } private void generateAliasForNewOutputSlots(Expression expression) { - Alias alias = new Alias(expression); + Expression replacedExpr = ExpressionUtils.replace(expression, outputSubstitutionMap); + Alias alias = new Alias(replacedExpr); newOutputSlots.add(alias); substitution.put(expression, alias.toSlot()); } diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java index 0ebaa6f3c38..f270ff4db0e 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java @@ -32,6 +32,7 @@ import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; import org.apache.doris.nereids.trees.expressions.functions.agg.Min; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Abs; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral; import org.apache.doris.nereids.types.BigIntType; @@ -513,6 +514,28 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, countStar)))) ).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(countStar.toSlot(), true, true)))) ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot())))); + sql = "SELECT abs(a1) xx, sum(a2) FROM t1 GROUP BY xx ORDER BY MIN(xx)"; + a1 = new SlotReference( + new ExprId(1), "a1", TinyIntType.INSTANCE, true, + ImmutableList.of("test_resolve_aggregate_functions", "t1") + ); + Alias xx = new Alias(new ExprId(3), new Abs(a1), "xx"); + a2 = new SlotReference( + new ExprId(2), "a2", TinyIntType.INSTANCE, true, + ImmutableList.of("test_resolve_aggregate_functions", "t1") + ); + sumA2 = new Alias(new ExprId(4), new Sum(a2), "sum(a2)"); + + Alias minXX = new Alias(new ExprId(5), new Min(xx.toSlot()), "min(xx)"); + PlanChecker.from(connectContext).analyze(sql).printlnTree().matches(logicalProject( + 
logicalSort(logicalProject(logicalAggregate(logicalProject(logicalOlapScan()) + .when(FieldChecker.check("projects", Lists.newArrayList(xx, a2, a1)))))) + .when(FieldChecker.check("orderKeys", + ImmutableList + .of(new OrderKey(minXX.toSlot(), true, true))))) + .when(FieldChecker.check("projects", + Lists.newArrayList(xx.toSlot(), + sumA2.toSlot())))); } @Test --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
