This is an automated email from the ASF dual-hosted git repository.

morningman pushed a commit to branch branch-1.2-unstable
in repository https://gitbox.apache.org/repos/asf/doris.git

commit dd2e4b04251619c72494807ac8fe064bab85cdf0
Author: morrySnow <[email protected]>
AuthorDate: Wed Nov 9 13:43:12 2022 +0800

    [fix](Nereids) aggregate disassemble generates an incorrect output list for the GLOBAL phase aggregate (#14079)
    
    We must use localAggregateFunction as the key of globalOutputSMap, because disassembleDistinct uses the local output expressions to generate the global output.
---
 .../rules/rewrite/AggregateDisassemble.java        |  9 ++--
 .../rewrite/logical/AggregateDisassembleTest.java  | 63 ++++++++++++++++++++++
 2 files changed, 69 insertions(+), 3 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
index c4e3db8765..4a632e188b 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
@@ -188,11 +188,12 @@ public class AggregateDisassemble extends 
OneRewriteRuleFactory {
                     continue;
                 }
 
-                NamedExpression localOutputExpr = new 
Alias(aggregateFunction.withAggregateParam(
+                AggregateFunction localAggregateFunction = 
aggregateFunction.withAggregateParam(
                         aggregateFunction.getAggregateParam()
                                 .withDistinct(false)
                                 .withGlobal(false)
-                ), aggregateFunction.toSql());
+                );
+                NamedExpression localOutputExpr = new 
Alias(localAggregateFunction, aggregateFunction.toSql());
 
                 List<DataType> inputTypesBeforeDissemble = 
aggregateFunction.children()
                         .stream()
@@ -207,7 +208,9 @@ public class AggregateDisassemble extends 
OneRewriteRuleFactory {
                         
.withChildren(Lists.newArrayList(localOutputExpr.toSlot()));
 
                 inputSubstitutionMap.put(aggregateFunction, substitutionValue);
-                globalOutputSMap.put(aggregateFunction, substitutionValue);
+                // because we use local output exprs to generate global output 
in disassembleDistinct,
+                // so we must use localAggregateFunction as key
+                globalOutputSMap.put(localAggregateFunction, 
substitutionValue);
                 localOutputExprs.add(localOutputExpr);
             }
         }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
index 2f3e5303fb..af737bdeb3 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
@@ -22,6 +22,7 @@ import org.apache.doris.nereids.trees.expressions.Add;
 import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
@@ -38,11 +39,13 @@ import org.apache.doris.nereids.util.PlanConstructor;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Lists;
+import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.TestInstance;
 
 import java.util.List;
+import java.util.Optional;
 
 @TestInstance(TestInstance.Lifecycle.PER_CLASS)
 public class AggregateDisassembleTest implements PatternMatchSupported {
@@ -261,4 +264,64 @@ public class AggregateDisassembleTest implements 
PatternMatchSupported {
                                         0).getExprId())
                 );
     }
+
+    @Test
+    public void distinctWithNormalAggregateFunction() {
+        List<Expression> groupExpressionList = 
Lists.newArrayList(rStudent.getOutput().get(0).toSlot());
+        List<NamedExpression> outputExpressionList = Lists.newArrayList(
+                new Alias(new Count(AggregateParam.distinctAndGlobal(), 
rStudent.getOutput().get(2).toSlot()), "c"),
+                new Alias(new Sum(rStudent.getOutput().get(0).toSlot()), 
"sum"));
+        Plan root = new LogicalAggregate<>(groupExpressionList, 
outputExpressionList, rStudent);
+
+        // check local:
+        // id
+        Expression localOutput0 = rStudent.getOutput().get(0);
+        // sum
+        Sum localOutput1 = new Sum(new AggregateParam(false, false, 
Optional.empty()), rStudent.getOutput().get(0).toSlot());
+        // age
+        Expression localOutput2 = rStudent.getOutput().get(2);
+        // id
+        Expression localGroupBy0 = rStudent.getOutput().get(0);
+        // age
+        Expression localGroupBy1 = rStudent.getOutput().get(2);
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+                .applyTopDown(new AggregateDisassemble())
+                .matchesFromRoot(
+                        logicalAggregate(
+                                logicalAggregate(
+                                        logicalAggregate()
+                                                .when(agg -> 
agg.getAggPhase().equals(AggPhase.LOCAL))
+                                                .when(agg -> 
agg.getOutputExpressions().get(0).equals(localOutput0))
+                                                .when(agg -> 
agg.getOutputExpressions().get(1).child(0).equals(localOutput1))
+                                                .when(agg -> 
agg.getOutputExpressions().get(2).equals(localOutput2))
+                                                .when(agg -> 
agg.getGroupByExpressions().get(0).equals(localGroupBy0))
+                                                .when(agg -> 
agg.getGroupByExpressions().get(1).equals(localGroupBy1))
+                                ).when(agg -> 
agg.getAggPhase().equals(AggPhase.GLOBAL))
+                                        .when(agg -> 
agg.getOutputExpressions().get(0)
+                                                
.equals(agg.child().getOutputExpressions().get(0)))
+                                        .when(agg -> {
+                                            Slot child = 
agg.child().getOutputExpressions().get(1).toSlot();
+                                            
Assertions.assertTrue(agg.getOutputExpressions().get(1).child(0) instanceof 
Sum);
+                                            return ((Sum) 
agg.getOutputExpressions().get(1).child(0)).child().equals(child);
+                                        })
+                                        .when(agg -> 
agg.getOutputExpressions().get(2)
+                                                
.equals(agg.child().getOutputExpressions().get(2)))
+                                        .when(agg -> 
agg.getGroupByExpressions().get(0)
+                                                
.equals(agg.child().getOutputExpressions().get(0)))
+                                        .when(agg -> 
agg.getGroupByExpressions().get(1)
+                                                
.equals(agg.child().getOutputExpressions().get(2)))
+                        ).when(agg -> 
agg.getAggPhase().equals(AggPhase.DISTINCT_LOCAL))
+                                .when(agg -> agg.getOutputExpressions().size() 
== 2)
+                                .when(agg -> agg.getOutputExpressions().get(0) 
instanceof Alias)
+                                .when(agg -> 
agg.getOutputExpressions().get(0).child(0) instanceof Count)
+                                .when(agg -> 
agg.getOutputExpressions().get(1).child(0) instanceof Sum)
+                                .when(agg -> 
agg.getOutputExpressions().get(0).getExprId() == outputExpressionList.get(
+                                        0).getExprId())
+                                .when(agg -> 
agg.getOutputExpressions().get(1).getExprId() == outputExpressionList.get(
+                                        1).getExprId())
+                                .when(agg -> agg.getGroupByExpressions().get(0)
+                                        
.equals(agg.child().child().getOutputExpressions().get(0)))
+                );
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to