This is an automated email from the ASF dual-hosted git repository.
morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 49a3bab399e [fix](nereids) fix aggregate function roll up when
expression arguments is not equals (#29256)
49a3bab399e is described below
commit 49a3bab399ef743ad38ef9bf2ec5f71e56edea14
Author: seawinde <[email protected]>
AuthorDate: Wed Jan 3 18:58:18 2024 +0800
[fix](nereids) fix aggregate function roll up when expression arguments is
not equals (#29256)
When rolling up an aggregate function, we should also check that the argument
of the query function is equal to the argument of the MV function.
For example, given the MV definition and query SQL below, the rewrite should
not succeed: the argument of the `bitmap_union_basic` field (which filters on
`o_shippriority > 1`) is not equal to the argument of the
`count(distinct case when o_shippriority > 10 and o_orderkey IN (1, 3) then
o_custkey else null end)` field in the query (which filters on
`o_shippriority > 10`).
MV definition:
> select l_shipdate, o_orderdate, l_partkey, l_suppkey,
> sum(o_totalprice) as sum_total,
> max(o_totalprice) as max_total,
> min(o_totalprice) as min_total,
> count(*) as count_all,
> bitmap_union(to_bitmap(case when o_shippriority > 1 and o_orderkey IN (1, 3) then o_custkey else null end)) as bitmap_union_basic
> from lineitem
> left join orders on lineitem.l_orderkey = orders.o_orderkey and l_shipdate = o_orderdate
> group by
> l_shipdate,
> o_orderdate,
> l_partkey,
> l_suppkey;
Query SQL:
> select t1.l_partkey, t1.l_suppkey, o_orderdate,
> sum(o_totalprice),
> max(o_totalprice),
> min(o_totalprice),
> count(*),
> count(distinct case when o_shippriority > 10 and o_orderkey IN (1, 3) then o_custkey else null end)
> from (select * from lineitem where l_shipdate = '2023-12-11') t1
> left join orders on t1.l_orderkey = orders.o_orderkey and t1.l_shipdate = o_orderdate
> group by
> o_orderdate,
> l_partkey,
> l_suppkey;
---
.../mv/AbstractMaterializedViewAggregateRule.java | 103 ++++++++++++++++-----
.../org/apache/doris/nereids/trees/TreeNode.java | 28 ++++++
.../doris/nereids/trees/expressions/Any.java | 10 ++
.../mv/agg_with_roll_up/aggregate_with_roll_up.out | 6 ++
.../agg_with_roll_up/aggregate_with_roll_up.groovy | 35 +++++++
5 files changed, 158 insertions(+), 24 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java
index 685f8a8c3a9..11faaa6a6d3 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java
@@ -26,6 +26,7 @@ import
org.apache.doris.nereids.rules.exploration.mv.StructInfo.PlanSplitContext
import org.apache.doris.nereids.rules.exploration.mv.mapping.ExpressionMapping;
import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping;
import org.apache.doris.nereids.trees.expressions.Any;
+import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
@@ -35,11 +36,14 @@ import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunctio
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.CouldRollUp;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.ToBitmap;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.util.ExpressionUtils;
+import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
@@ -47,10 +51,11 @@ import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.ArrayList;
-import java.util.HashMap;
+import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
@@ -60,15 +65,18 @@ import java.util.stream.Collectors;
*/
public abstract class AbstractMaterializedViewAggregateRule extends
AbstractMaterializedViewRule {
- protected static final Map<Expression, Expression>
- AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP = new HashMap<>();
+ // we only support roll up function which has only one argument currently
+ protected static final Multimap<Expression, Expression>
+ AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP =
ArrayListMultimap.create();
protected final String currentClassName = this.getClass().getSimpleName();
private final Logger logger = LogManager.getLogger(this.getClass());
static {
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new Count(true,
Any.INSTANCE),
- new BitmapUnion(Any.INSTANCE));
+ new BitmapUnion(new ToBitmap(new Cast(Any.INSTANCE,
BigIntType.INSTANCE))));
+ AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new Count(true,
Any.INSTANCE),
+ new BitmapUnion(new ToBitmap(Any.INSTANCE)));
}
@Override
@@ -249,17 +257,30 @@ public abstract class
AbstractMaterializedViewAggregateRule extends AbstractMate
return rewrittenAggregate;
}
- // only support sum roll up, support other agg functions later.
- private Function rollup(AggregateFunction queryFunction,
- Expression queryFunctionShuttled,
+ /**
+ * Roll up query aggregate function when query dimension num is less than
mv dimension num,
+ *
+ * @param queryAggregateFunction query aggregate function to roll up.
+ * @param queryAggregateFunctionShuttled query aggregate function shuttled
by lineage.
+ * @param mvExprToMvScanExprQueryBased mv def sql output expressions to mv
result data output mapping.
+ * <p>
+ * Such as query is
+ * select max(a) + 1 from table group by b.
+ * mv is
+ * select max(a) from table group by a, b.
+ * the queryAggregateFunction is max(a),
queryAggregateFunctionShuttled is max(a) + 1
+ * mvExprToMvScanExprQueryBased is { max(a) : MTMVScan(output#0) }
+ */
+ private Function rollup(AggregateFunction queryAggregateFunction,
+ Expression queryAggregateFunctionShuttled,
Map<Expression, Expression> mvExprToMvScanExprQueryBased) {
- if (!(queryFunction instanceof CouldRollUp)) {
+ if (!(queryAggregateFunction instanceof CouldRollUp)) {
return null;
}
Expression rollupParam = null;
- if (mvExprToMvScanExprQueryBased.containsKey(queryFunctionShuttled)) {
+ if
(mvExprToMvScanExprQueryBased.containsKey(queryAggregateFunctionShuttled)) {
// function can rewrite by view
- rollupParam =
mvExprToMvScanExprQueryBased.get(queryFunctionShuttled);
+ rollupParam =
mvExprToMvScanExprQueryBased.get(queryAggregateFunctionShuttled);
} else {
// function can not rewrite by view, try to use complex roll up
param
// eg: query is count(distinct param), mv sql is
bitmap_union(to_bitmap(param))
@@ -267,7 +288,8 @@ public abstract class AbstractMaterializedViewAggregateRule
extends AbstractMate
if (!(mvExprShuttled instanceof Function)) {
continue;
}
- if (isAggregateFunctionEquivalent(queryFunction, (Function)
mvExprShuttled)) {
+ if (isAggregateFunctionEquivalent(queryAggregateFunction,
queryAggregateFunctionShuttled,
+ (Function) mvExprShuttled)) {
rollupParam =
mvExprToMvScanExprQueryBased.get(mvExprShuttled);
}
}
@@ -276,7 +298,7 @@ public abstract class AbstractMaterializedViewAggregateRule
extends AbstractMate
return null;
}
// do roll up
- return ((CouldRollUp) queryFunction).constructRollUp(rollupParam);
+ return ((CouldRollUp)
queryAggregateFunction).constructRollUp(rollupParam);
}
private Pair<Set<? extends Expression>, Set<? extends Expression>>
topPlanSplitToGroupAndFunction(
@@ -347,22 +369,55 @@ public abstract class
AbstractMaterializedViewAggregateRule extends AbstractMate
return true;
}
- private boolean isAggregateFunctionEquivalent(Function queryFunction,
Function viewFunction) {
+ /**
+ * Check the queryFunction is equivalent to view function when function
roll up.
+ * Not only check the function name but also check the argument between
query and view aggregate function.
+ * Such as query is
+ * select count(distinct a) + 1 from table group by b.
+ * mv is
+ * select bitmap_union(to_bitmap(a)) from table group by a, b.
+ * the queryAggregateFunction is count(distinct a),
queryAggregateFunctionShuttled is count(distinct a) + 1
+ * mvExprToMvScanExprQueryBased is { bitmap_union(to_bitmap(a)) :
MTMVScan(output#0) }
+ * This will check the count(distinct a) in query is equivalent to
bitmap_union(to_bitmap(a)) in mv,
+ * and then check their arguments is equivalent.
+ */
+ private boolean isAggregateFunctionEquivalent(Function queryFunction,
Expression queryFunctionShuttled,
+ Function viewFunction) {
if (queryFunction.equals(viewFunction)) {
return true;
}
- // get query equivalent function
- Expression equivalentFunction = null;
- for (Map.Entry<Expression, Expression> entry :
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.entrySet()) {
- if (entry.getKey().equals(queryFunction)) {
- equivalentFunction = entry.getValue();
+ // check the argument of rollup function is equivalent to view
function or not
+ for (Map.Entry<Expression, Collection<Expression>>
equivalentFunctionEntry :
+ AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.asMap().entrySet()) {
+ if (equivalentFunctionEntry.getKey().equals(queryFunction)) {
+ // check is have equivalent function or not
+ for (Expression equivalentFunction :
equivalentFunctionEntry.getValue()) {
+ if (!Any.equals(equivalentFunction, viewFunction)) {
+ continue;
+ }
+ // check param in query function is same as the view
function
+ List<Expression> viewFunctionArguments =
extractViewArguments(equivalentFunction, viewFunction);
+ if (queryFunctionShuttled.getArguments().size() != 1 ||
viewFunctionArguments.size() != 1) {
+ continue;
+ }
+ if
(Objects.equals(queryFunctionShuttled.getArguments().get(0),
viewFunctionArguments.get(0))) {
+ return true;
+ }
+ }
}
}
- // check is have equivalent function or not
- if (equivalentFunction == null) {
- return false;
- }
- // current compare
- return equivalentFunction.equals(viewFunction);
+ return false;
+ }
+
+ /**
+ * Extract the view function arguments by equivalentFunction pattern
+ * Such as equivalentFunction def is bitmap_union(to_bitmap(Any.INSTANCE)),
+ * viewFunction is bitmap_union(to_bitmap(case when a = 5 then 1 else 2
end))
+ * after extracting, the return argument is: case when a = 5 then 1 else 2
end
+ */
+ private List<Expression> extractViewArguments(Expression
equivalentFunction, Function viewFunction) {
+ Set<Object> exprSetToRemove = equivalentFunction.collectToSet(expr ->
!(expr instanceof Any));
+ return viewFunction.collectFirst(expr ->
+ exprSetToRemove.stream().noneMatch(exprToRemove ->
exprToRemove.equals(expr)));
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java
index 557ff43b51d..00ac71eaf24 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java
@@ -22,6 +22,7 @@ import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayDeque;
+import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.Set;
@@ -150,6 +151,19 @@ public interface TreeNode<NODE_TYPE extends
TreeNode<NODE_TYPE>> {
return rewriteFunction.apply(rewrittenChildren);
}
+ /**
+ * Foreach treeNode. Top-down traverse implicitly, stop traverse if
satisfy test.
+ * @param func foreach function
+ */
+ default void foreach(Predicate<TreeNode<NODE_TYPE>> func) {
+ boolean valid = func.test(this);
+ if (!valid) {
+ for (NODE_TYPE child : children()) {
+ child.foreach(func);
+ }
+ }
+ }
+
/**
* Foreach treeNode. Top-down traverse implicitly.
* @param func foreach function
@@ -241,6 +255,20 @@ public interface TreeNode<NODE_TYPE extends
TreeNode<NODE_TYPE>> {
return (Set<T>) result.build();
}
+ /**
+ * Collect the nodes that satisfied the predicate firstly.
+ */
+ default <T> List<T> collectFirst(Predicate<TreeNode<NODE_TYPE>> predicate)
{
+ List<TreeNode<NODE_TYPE>> result = new ArrayList<>();
+ foreach(node -> {
+ if (result.isEmpty() && predicate.test(node)) {
+ result.add(node);
+ }
+ return !result.isEmpty();
+ });
+ return (List<T>) ImmutableList.copyOf(result);
+ }
+
/**
* iterate top down and test predicate if contains any instance of the
classes
* @param types classes array
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Any.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Any.java
index 43d284bf678..2e4bc745b2a 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Any.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Any.java
@@ -24,6 +24,7 @@ import
org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import com.google.common.collect.ImmutableList;
import java.util.List;
+import java.util.Objects;
/**
* This represents any expression, it means it equals any expression
@@ -55,6 +56,15 @@ public class Any extends Expression implements
LeafExpression {
return true;
}
+ /**
+ * Equals with direction
+ * Since the equals method in Any is always true, that means Any is equals
to others, but not equal in reverse.
+ * The expression with Any should always be the first argument.
+ */
+ public static boolean equals(Expression expressionWithAny, Expression
target) {
+ return Objects.equals(expressionWithAny, target);
+ }
+
@Override
public int hashCode() {
return 0;
diff --git
a/regression-test/data/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.out
b/regression-test/data/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.out
index fb223bc661b..334980ed00c 100644
---
a/regression-test/data/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.out
+++
b/regression-test/data/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.out
@@ -5,6 +5,12 @@
-- !query13_0_after --
3 3 2023-12-11 43.20 43.20 43.20 1 0
+-- !query13_1_before --
+3 3 2023-12-11 43.20 43.20 43.20 1 0
+
+-- !query13_1_after --
+3 3 2023-12-11 43.20 43.20 43.20 1 0
+
-- !query14_0_before --
2 3 2023-12-08 20.00 10.50 9.50 2 0
2 3 2023-12-12 \N \N \N 1 0
diff --git
a/regression-test/suites/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.groovy
b/regression-test/suites/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.groovy
index fd3c02408d9..e9d1ee76b37 100644
---
a/regression-test/suites/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.groovy
+++
b/regression-test/suites/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.groovy
@@ -247,6 +247,41 @@ suite("aggregate_with_roll_up") {
sql """ DROP MATERIALIZED VIEW IF EXISTS mv13_0"""
+ def mv13_1 = """
+ select l_shipdate, o_orderdate, l_partkey, l_suppkey,
+ sum(o_totalprice) as sum_total,
+ max(o_totalprice) as max_total,
+ min(o_totalprice) as min_total,
+ count(*) as count_all,
+ bitmap_union(to_bitmap(case when o_shippriority > 1 and o_orderkey
IN (1, 3) then o_custkey else null end)) as bitmap_union_basic
+ from lineitem
+ left join orders on lineitem.l_orderkey = orders.o_orderkey and
l_shipdate = o_orderdate
+ group by
+ l_shipdate,
+ o_orderdate,
+ l_partkey,
+ l_suppkey;
+ """
+ def query13_1 = """
+ select t1.l_partkey, t1.l_suppkey, o_orderdate,
+ sum(o_totalprice),
+ max(o_totalprice),
+ min(o_totalprice),
+ count(*),
+ count(distinct case when o_shippriority > 10 and o_orderkey IN (1,
3) then o_custkey else null end)
+ from (select * from lineitem where l_shipdate = '2023-12-11') t1
+ left join orders on t1.l_orderkey = orders.o_orderkey and
t1.l_shipdate = o_orderdate
+ group by
+ o_orderdate,
+ l_partkey,
+ l_suppkey;
+ """
+ order_qt_query13_1_before "${query13_1}"
+ check_not_match(mv13_1, query13_1, "mv13_1")
+ order_qt_query13_1_after "${query13_1}"
+ sql """ DROP MATERIALIZED VIEW IF EXISTS mv13_1"""
+
+
// filter inside + right + use roll up dimension
def mv14_0 = "select l_shipdate, o_orderdate, l_partkey, l_suppkey, " +
"sum(o_totalprice) as sum_total, " +
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]