This is an automated email from the ASF dual-hosted git repository.

panxiaolei pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-2.0 by this push:
     new 7bda49b506 [Feature](materialized view) support query match mv with agg_state on nereids planner (#21067) (#21669)
7bda49b506 is described below

commit 7bda49b5065966eed4405b38665e141110fea16a
Author: Pxl <[email protected]>
AuthorDate: Mon Jul 10 14:43:16 2023 +0800

    [Feature](materialized view) support query match mv with agg_state on nereids planner (#21067) (#21669)
    
    * support create mv contain aggstate column
    
    * update
    
    * update
    
    * update
    
    * support query match mv with agg_state on nereids planner
    
    update
    
    * update
    
    * update
---
 be/src/util/timezone_utils.cpp                     |  2 +-
 .../apache/doris/analysis/NativeInsertStmt.java    |  4 +-
 .../mv/SelectMaterializedIndexWithAggregate.java   | 55 ++++++++++++++++++++++
 .../functions/combinator/StateCombinator.java      |  4 ++
 .../data/mv_p0/agg_state/test_agg_state_max_by.out | 22 +++++++--
 .../mv_p0/agg_state/test_agg_state_max_by.groovy   | 41 ++++++++++++----
 6 files changed, 114 insertions(+), 14 deletions(-)

diff --git a/be/src/util/timezone_utils.cpp b/be/src/util/timezone_utils.cpp
index 2b22e52126..e4d19946a7 100644
--- a/be/src/util/timezone_utils.cpp
+++ b/be/src/util/timezone_utils.cpp
@@ -33,7 +33,7 @@ bool TimezoneUtils::find_cctz_time_zone(const std::string& 
timezone, cctz::time_
                                           1)) {
         bool positive = value[0] != '-';
 
-        //Regular expression guarantees hour and minute mush be int
+        //Regular expression guarantees hour and minute must be int
         int hour = std::stoi(value.substr(1, 2).as_string());
         int minute = std::stoi(value.substr(4, 2).as_string());
 
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/analysis/NativeInsertStmt.java 
b/fe/fe-core/src/main/java/org/apache/doris/analysis/NativeInsertStmt.java
index 408fcc0952..591f8191f5 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/NativeInsertStmt.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/NativeInsertStmt.java
@@ -448,11 +448,11 @@ public class NativeInsertStmt extends InsertStmt {
                 }
                 targetColumns.add(col);
             }
-            // hll column mush in mentionedColumns
+            // object-stored columns (e.g. hll) must be in mentionedColumns
             for (Column col : targetTable.getBaseSchema()) {
                 if (col.getType().isObjectStored() && 
!mentionedColumns.contains(col.getName())) {
                     throw new AnalysisException(
-                            " object-stored column " + col.getName() + " mush 
in insert into columns");
+                            "object-stored column " + col.getName() + " must 
in insert into columns");
                 }
             }
         }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java
index 6c04521e79..7d0b57d649 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java
@@ -28,12 +28,14 @@ import 
org.apache.doris.nereids.exceptions.AnalysisException;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
 import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
+import 
org.apache.doris.nereids.rules.rewrite.mv.AbstractSelectMaterializedIndexRule.SlotContext;
 import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.ExprId;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.SlotNotFromChildren;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
 import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
 import 
org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount;
@@ -45,6 +47,8 @@ import 
org.apache.doris.nereids.trees.expressions.functions.agg.Max;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Ndv;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import 
org.apache.doris.nereids.trees.expressions.functions.combinator.MergeCombinator;
+import 
org.apache.doris.nereids.trees.expressions.functions.combinator.StateCombinator;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.BitmapHash;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.HllHash;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.ToBitmap;
@@ -67,6 +71,7 @@ import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
 import com.google.common.collect.Streams;
 
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
@@ -1159,6 +1164,10 @@ public class SelectMaterializedIndexWithAggregate 
extends AbstractSelectMaterial
          */
         @Override
         public Expression visitCount(Count count, RewriteContext context) {
+            Expression result = visitAggregateFunction(count, context);
+            if (result != count) {
+                return result;
+            }
             if (count.isDistinct() && count.arity() == 1) {
                 // count(distinct col) -> 
bitmap_union_count(mv_bitmap_union_col)
                 Optional<Slot> slotOpt = 
ExpressionUtils.extractSlotOrCastOnSlot(count.child(0));
@@ -1225,6 +1234,10 @@ public class SelectMaterializedIndexWithAggregate 
extends AbstractSelectMaterial
          */
         @Override
         public Expression visitBitmapUnionCount(BitmapUnionCount 
bitmapUnionCount, RewriteContext context) {
+            Expression result = visitAggregateFunction(bitmapUnionCount, 
context);
+            if (result != bitmapUnionCount) {
+                return result;
+            }
             if (bitmapUnionCount.child() instanceof ToBitmap) {
                 ToBitmap toBitmap = (ToBitmap) bitmapUnionCount.child();
                 Optional<Slot> slotOpt = 
ExpressionUtils.extractSlotOrCastOnSlot(toBitmap.child());
@@ -1291,6 +1304,10 @@ public class SelectMaterializedIndexWithAggregate 
extends AbstractSelectMaterial
          */
         @Override
         public Expression visitHllUnion(HllUnion hllUnion, RewriteContext 
context) {
+            Expression result = visitAggregateFunction(hllUnion, context);
+            if (result != hllUnion) {
+                return result;
+            }
             if (hllUnion.child() instanceof HllHash) {
                 HllHash hllHash = (HllHash) hllUnion.child();
                 Optional<Slot> slotOpt = 
ExpressionUtils.extractSlotOrCastOnSlot(hllHash.child());
@@ -1327,6 +1344,10 @@ public class SelectMaterializedIndexWithAggregate 
extends AbstractSelectMaterial
          */
         @Override
         public Expression visitHllUnionAgg(HllUnionAgg hllUnionAgg, 
RewriteContext context) {
+            Expression result = visitAggregateFunction(hllUnionAgg, context);
+            if (result != hllUnionAgg) {
+                return result;
+            }
             if (hllUnionAgg.child() instanceof HllHash) {
                 HllHash hllHash = (HllHash) hllUnionAgg.child();
                 Optional<Slot> slotOpt = 
ExpressionUtils.extractSlotOrCastOnSlot(hllHash.child());
@@ -1363,6 +1384,10 @@ public class SelectMaterializedIndexWithAggregate 
extends AbstractSelectMaterial
          */
         @Override
         public Expression visitNdv(Ndv ndv, RewriteContext context) {
+            Expression result = visitAggregateFunction(ndv, context);
+            if (result != ndv) {
+                return result;
+            }
             Optional<Slot> slotOpt = 
ExpressionUtils.extractSlotOrCastOnSlot(ndv.child(0));
             // ndv on a value column.
             if (slotOpt.isPresent() && 
!context.checkContext.keyNameToColumn.containsKey(
@@ -1391,6 +1416,36 @@ public class SelectMaterializedIndexWithAggregate 
extends AbstractSelectMaterial
             }
             return ndv;
         }
+
+        /**
+         * agg(col) -> agg_merge(mva_generic_aggregation__agg_state(col)), e.g.
+         * max_by(k2, k3) -> max_by_merge(mva_generic_aggregation__max_by_state(k2, k3))
+         */
+        @Override
+        public Expression visitAggregateFunction(AggregateFunction 
aggregateFunction, RewriteContext context) {
+            String aggStateName = 
normalizeName(CreateMaterializedViewStmt.mvColumnBuilder(
+                    AggregateType.GENERIC_AGGREGATION, 
StateCombinator.create(aggregateFunction).toSql()));
+
+            Column mvColumn = 
context.checkContext.scan.getTable().getVisibleColumn(aggStateName);
+            if (mvColumn != null && 
context.checkContext.valueNameToColumn.containsValue(mvColumn)) {
+                Slot aggStateSlot = 
context.checkContext.scan.getOutputByIndex(context.checkContext.index).stream()
+                        .filter(s -> 
aggStateName.equalsIgnoreCase(normalizeName(s.getName()))).findFirst()
+                        .orElseThrow(() -> new AnalysisException("cannot find 
agg state slot when select mv"));
+
+                Set<Slot> slots = 
aggregateFunction.collect(SlotReference.class::isInstance);
+                for (Slot slot : slots) {
+                    if 
(!context.checkContext.keyNameToColumn.containsKey(normalizeName(slot.toSql())))
 {
+                        context.exprRewriteMap.slotMap.put(slot, aggStateSlot);
+                        context.exprRewriteMap.projectExprMap.put(slot, 
aggStateSlot);
+                    }
+                }
+
+                MergeCombinator mergeCombinator = new 
MergeCombinator(Arrays.asList(aggStateSlot), aggregateFunction);
+                context.exprRewriteMap.aggFuncMap.put(aggregateFunction, 
mergeCombinator);
+                return mergeCombinator;
+            }
+            return aggregateFunction;
+        }
     }
 
     private List<NamedExpression> replaceAggOutput(
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/StateCombinator.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/StateCombinator.java
index 9b97a7afd4..db001a6793 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/StateCombinator.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/StateCombinator.java
@@ -57,6 +57,10 @@ public class StateCombinator extends ScalarFunction
         }).collect(ImmutableList.toImmutableList()));
     }
 
+    public static StateCombinator create(AggregateFunction nested) {
+        return new StateCombinator(nested.getArguments(), nested);
+    }
+
     @Override
     public StateCombinator withChildren(List<Expression> children) {
         return new StateCombinator(children, nested);
diff --git a/regression-test/data/mv_p0/agg_state/test_agg_state_max_by.out 
b/regression-test/data/mv_p0/agg_state/test_agg_state_max_by.out
index e8082f928a..406e9fa334 100644
--- a/regression-test/data/mv_p0/agg_state/test_agg_state_max_by.out
+++ b/regression-test/data/mv_p0/agg_state/test_agg_state_max_by.out
@@ -1,8 +1,24 @@
 -- This file is automatically generated. You should know what you did if you 
want to edit this
 -- !select_star --
 \N     4       \N      d
--4     -4      -4      d
+1      -4      -4      d
+1      -3      \N      c
 1      1       1       a
-2      2       2       b
-3      -3      \N      c
+1      2       2       b
+
+-- !select_mv --
+\N     \N
+1      2
+
+-- !select_mv --
+\N     \N
+1      4
+
+-- !select_mv --
+\N     \N
+1      4
+
+-- !select_mv --
+\N     \N
+1      -4
 
diff --git 
a/regression-test/suites/mv_p0/agg_state/test_agg_state_max_by.groovy 
b/regression-test/suites/mv_p0/agg_state/test_agg_state_max_by.groovy
index 071f36bb69..7fd7d4fef1 100644
--- a/regression-test/suites/mv_p0/agg_state/test_agg_state_max_by.groovy
+++ b/regression-test/suites/mv_p0/agg_state/test_agg_state_max_by.groovy
@@ -36,20 +36,45 @@ suite ("test_agg_state_max_by") {
         """
 
     sql "insert into d_table select 1,1,1,'a';"
-    sql "insert into d_table select 2,2,2,'b';"
-    sql "insert into d_table select 3,-3,null,'c';"
+    sql "insert into d_table select 1,2,2,'b';"
+    sql "insert into d_table select 1,-3,null,'c';"
     sql "insert into d_table(k4,k2) values('d',4);"
 
     createMV("create materialized view k1mb as select k1,max_by(k2,k3) from 
d_table group by k1;")
 
-    sql "insert into d_table select -4,-4,-4,'d';"
+    sql "insert into d_table select 1,-4,-4,'d';"
 
-    qt_select_star "select * from d_table order by k1;"
-/*
+    qt_select_star "select * from d_table order by 1,2;"
     explain {
-        sql("select k1,max_by(k2,k3) from d_table group by k1 order by k1;")
+        sql("select k1,max_by(k2,k3) from d_table group by k1 order by 1,2;")
         contains "(k1mb)"
     }
-    qt_select_mv "select k1,max_by(k2,k3) from d_table group by k1 order by 
k1;"
-*/
+    qt_select_mv "select k1,max_by(k2,k3) from d_table group by k1 order by 
1,2;"
+
+    createMV("create materialized view k1mbcp1 as select 
k1,max_by(k2+k3,abs(k3)) from d_table group by k1;")
+    createMV("create materialized view k1mbcp2 as select k1,max_by(k2+k3,k3) 
from d_table group by k1;")
+    createMV("create materialized view k1mbcp3 as select k1,max_by(k2,abs(k3)) 
from d_table group by k1;")
+
+    sql "insert into d_table(k4,k2) values('d',4);"
+    sql "set enable_nereids_dml = true"
+    sql "insert into d_table(k4,k2) values('d',4);"
+    sql "insert into d_table select 1,-4,-4,'d';"
+
+    explain {
+        sql("select k1,max_by(k2+k3,abs(k3)) from d_table group by k1 order by 
1,2;")
+        contains "(k1mbcp1)"
+    }
+    qt_select_mv "select k1,max_by(k2+k3,k3) from d_table group by k1 order by 
1,2;"
+
+    explain {
+        sql("select k1,max_by(k2+k3,k3) from d_table group by k1 order by 
1,2;")
+        contains "(k1mbcp2)"
+    }
+    qt_select_mv "select k1,max_by(k2+k3,k3) from d_table group by k1 order by 
1,2;"
+
+    explain {
+        sql("select k1,max_by(k2,abs(k3)) from d_table group by k1 order by 
1,2;")
+        contains "(k1mbcp3)"
+    }
+    qt_select_mv "select k1,max_by(k2,abs(k3)) from d_table group by k1 order 
by 1,2;"
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to