This is an automated email from the ASF dual-hosted git repository. morningman pushed a commit to branch branch-1.2-unstable in repository https://gitbox.apache.org/repos/asf/doris.git
commit dd2e4b04251619c72494807ac8fe064bab85cdf0 Author: morrySnow <[email protected]> AuthorDate: Wed Nov 9 13:43:12 2022 +0800 [fix](Nereids) aggregate disassemble generates an incorrect output list for the GLOBAL phase aggregate (#14079) We must use localAggregateFunction as the key of globalOutputSMap, because we use the local output exprs to generate the global output in disassembleDistinct. --- .../rules/rewrite/AggregateDisassemble.java | 9 ++-- .../rewrite/logical/AggregateDisassembleTest.java | 63 ++++++++++++++++++++++ 2 files changed, 69 insertions(+), 3 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java index c4e3db8765..4a632e188b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java @@ -188,11 +188,12 @@ public class AggregateDisassemble extends OneRewriteRuleFactory { continue; } - NamedExpression localOutputExpr = new Alias(aggregateFunction.withAggregateParam( + AggregateFunction localAggregateFunction = aggregateFunction.withAggregateParam( aggregateFunction.getAggregateParam() .withDistinct(false) .withGlobal(false) - ), aggregateFunction.toSql()); + ); + NamedExpression localOutputExpr = new Alias(localAggregateFunction, aggregateFunction.toSql()); List<DataType> inputTypesBeforeDissemble = aggregateFunction.children() .stream() @@ -207,7 +208,9 @@ public class AggregateDisassemble extends OneRewriteRuleFactory { .withChildren(Lists.newArrayList(localOutputExpr.toSlot())); inputSubstitutionMap.put(aggregateFunction, substitutionValue); - globalOutputSMap.put(aggregateFunction, substitutionValue); + // because we use local output exprs to generate global output in disassembleDistinct, + // so we must use localAggregateFunction as key +
globalOutputSMap.put(localAggregateFunction, substitutionValue); localOutputExprs.add(localOutputExpr); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java index 2f3e5303fb..af737bdeb3 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java @@ -22,6 +22,7 @@ import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; @@ -38,11 +39,13 @@ import org.apache.doris.nereids.util.PlanConstructor; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import java.util.List; +import java.util.Optional; @TestInstance(TestInstance.Lifecycle.PER_CLASS) public class AggregateDisassembleTest implements PatternMatchSupported { @@ -261,4 +264,64 @@ public class AggregateDisassembleTest implements PatternMatchSupported { 0).getExprId()) ); } + + @Test + public void distinctWithNormalAggregateFunction() { + List<Expression> groupExpressionList = Lists.newArrayList(rStudent.getOutput().get(0).toSlot()); + List<NamedExpression> outputExpressionList = Lists.newArrayList( + new Alias(new 
Count(AggregateParam.distinctAndGlobal(), rStudent.getOutput().get(2).toSlot()), "c"), + new Alias(new Sum(rStudent.getOutput().get(0).toSlot()), "sum")); + Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, rStudent); + + // check local: + // id + Expression localOutput0 = rStudent.getOutput().get(0); + // sum + Sum localOutput1 = new Sum(new AggregateParam(false, false, Optional.empty()), rStudent.getOutput().get(0).toSlot()); + // age + Expression localOutput2 = rStudent.getOutput().get(2); + // id + Expression localGroupBy0 = rStudent.getOutput().get(0); + // age + Expression localGroupBy1 = rStudent.getOutput().get(2); + + PlanChecker.from(MemoTestUtils.createConnectContext(), root) + .applyTopDown(new AggregateDisassemble()) + .matchesFromRoot( + logicalAggregate( + logicalAggregate( + logicalAggregate() + .when(agg -> agg.getAggPhase().equals(AggPhase.LOCAL)) + .when(agg -> agg.getOutputExpressions().get(0).equals(localOutput0)) + .when(agg -> agg.getOutputExpressions().get(1).child(0).equals(localOutput1)) + .when(agg -> agg.getOutputExpressions().get(2).equals(localOutput2)) + .when(agg -> agg.getGroupByExpressions().get(0).equals(localGroupBy0)) + .when(agg -> agg.getGroupByExpressions().get(1).equals(localGroupBy1)) + ).when(agg -> agg.getAggPhase().equals(AggPhase.GLOBAL)) + .when(agg -> agg.getOutputExpressions().get(0) + .equals(agg.child().getOutputExpressions().get(0))) + .when(agg -> { + Slot child = agg.child().getOutputExpressions().get(1).toSlot(); + Assertions.assertTrue(agg.getOutputExpressions().get(1).child(0) instanceof Sum); + return ((Sum) agg.getOutputExpressions().get(1).child(0)).child().equals(child); + }) + .when(agg -> agg.getOutputExpressions().get(2) + .equals(agg.child().getOutputExpressions().get(2))) + .when(agg -> agg.getGroupByExpressions().get(0) + .equals(agg.child().getOutputExpressions().get(0))) + .when(agg -> agg.getGroupByExpressions().get(1) + 
.equals(agg.child().getOutputExpressions().get(2))) + ).when(agg -> agg.getAggPhase().equals(AggPhase.DISTINCT_LOCAL)) + .when(agg -> agg.getOutputExpressions().size() == 2) + .when(agg -> agg.getOutputExpressions().get(0) instanceof Alias) + .when(agg -> agg.getOutputExpressions().get(0).child(0) instanceof Count) + .when(agg -> agg.getOutputExpressions().get(1).child(0) instanceof Sum) + .when(agg -> agg.getOutputExpressions().get(0).getExprId() == outputExpressionList.get( + 0).getExprId()) + .when(agg -> agg.getOutputExpressions().get(1).getExprId() == outputExpressionList.get( + 1).getExprId()) + .when(agg -> agg.getGroupByExpressions().get(0) + .equals(agg.child().child().getOutputExpressions().get(0))) + ); + } } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
