This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch branch-2.1 in repository https://gitbox.apache.org/repos/asf/doris.git
commit 2499ca6d89ba5c3d950c0bcff3b4d289afb82f7d Author: Xujian Duan <[email protected]> AuthorDate: Tue Jan 23 11:17:53 2024 +0800 [Enhancement](plan) Optimize preagg for aggregate function (#28886) --- .../main/java/org/apache/doris/analysis/Expr.java | 8 ++ .../org/apache/doris/analysis/LiteralExpr.java | 29 ++++++ .../mv/SelectMaterializedIndexWithAggregate.java | 103 +++++++++++++++++++++ .../nereids/trees/expressions/Expression.java | 4 + .../nereids/trees/expressions/literal/Literal.java | 23 +++++ .../apache/doris/planner/SingleNodePlanner.java | 47 ++++++++-- .../rules/rewrite/mv/SelectMvIndexTest.java | 89 ++++++++++++++++++ .../rules/rewrite/mv/SelectRollupIndexTest.java | 2 +- 8 files changed, 294 insertions(+), 11 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java index adb7621a346..465c2c947c8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java @@ -2627,5 +2627,13 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl expr.replaceSlot(tuple); } } + + public boolean isNullLiteral() { + return this instanceof NullLiteral; + } + + public boolean isZeroLiteral() { + return this instanceof LiteralExpr && ((LiteralExpr) this).isZero(); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/LiteralExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/LiteralExpr.java index ce89b2fc3c9..0814235f0a3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/LiteralExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/LiteralExpr.java @@ -34,8 +34,10 @@ import org.apache.logging.log4j.Logger; import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; +import java.math.BigDecimal; import java.nio.ByteBuffer; import java.util.List; +import java.util.Objects; import java.util.Optional; public abstract class LiteralExpr extends Expr implements Comparable<LiteralExpr> { @@ -449,4 +451,31 @@ public abstract class LiteralExpr extends Expr implements Comparable<LiteralExpr public boolean matchExprs(List<Expr> exprs, SelectStmt stmt, boolean ignoreAlias, TupleDescriptor tuple) { return true; } + + /** whether is ZERO value **/ + public boolean isZero() { + boolean isZero = false; + switch (type.getPrimitiveType()) { + case TINYINT: + case SMALLINT: + case INT: + case BIGINT: + case LARGEINT: + isZero = this.getLongValue() == 0; + break; + case FLOAT: + case DOUBLE: + isZero = this.getDoubleValue() == 0.0f; + break; + case DECIMALV2: + case DECIMAL32: + case DECIMAL64: + case DECIMAL128: + case DECIMAL256: + isZero = Objects.equals(((DecimalLiteral) this).getValue(), BigDecimal.ZERO); + break; + default: + } + return isZero; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java index a254e8d67d1..50d48ab8051 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java @@ -32,6 +32,7 @@ import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory; import org.apache.doris.nereids.rules.rewrite.mv.AbstractSelectMaterializedIndexRule.SlotContext; import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.CaseWhen; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; @@ -40,6 +41,7 @@ import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotNotFromChildren; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.VirtualSlotReference; +import org.apache.doris.nereids.trees.expressions.WhenClause; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion; import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount; @@ -53,8 +55,10 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; import org.apache.doris.nereids.trees.expressions.functions.combinator.MergeCombinator; import org.apache.doris.nereids.trees.expressions.functions.combinator.StateCombinator; import org.apache.doris.nereids.trees.expressions.functions.scalar.HllHash; +import org.apache.doris.nereids.trees.expressions.functions.scalar.If; import org.apache.doris.nereids.trees.expressions.functions.scalar.ToBitmap; import org.apache.doris.nereids.trees.expressions.functions.scalar.ToBitmapWithCheck; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; @@ -79,6 +83,7 @@ import com.google.common.collect.Maps; import com.google.common.collect.Sets; import com.google.common.collect.Streams; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; import java.util.List; @@ -879,6 +884,9 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial if (slotOpt.isPresent() && context.keyNameToColumn.containsKey(normalizeName(slotOpt.get().toSql()))) { return PreAggStatus.on(); } + if (count.child(0).arity() != 0) { + return checkSubExpressions(count, null, context); + } } return PreAggStatus.off(String.format( "Count distinct is only valid for key columns, but meet %s.", count.toSql())); @@ -963,11 +971,106 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial return PreAggStatus.off(String.format("Aggregate operator don't match, aggregate function: %s" + ", column aggregate type: %s", aggFunc.toSql(), aggType)); } + } else if (!aggFunc.child(0).children().isEmpty()) { + return checkSubExpressions(aggFunc, matchingAggType, ctx); } else { return PreAggStatus.off(String.format("Slot(%s) in %s is neither key column nor value column.", childNameWithFuncName, aggFunc.toSql())); } } + + // check sub expressions in AggregateFunction. + private PreAggStatus checkSubExpressions(AggregateFunction aggFunc, AggregateType matchingAggType, + CheckContext ctx) { + Expression child = aggFunc.child(0); + List<Expression> conditionExps = new ArrayList<>(); + List<Expression> returnExps = new ArrayList<>(); + + // ignore cast + while (child instanceof Cast) { + if (!((Cast) child).getDataType().isNumericType()) { + return PreAggStatus.off(String.format("[%s] is not numeric CAST.", child.toSql())); + } + child = child.child(0); + } + // step 1: extract all condition exprs and return exprs + if (child instanceof If) { + conditionExps.add(child.child(0)); + returnExps.add(child.child(1)); + returnExps.add(child.child(2)); + } else if (child instanceof CaseWhen) { + CaseWhen caseWhen = (CaseWhen) child; + // WHEN THEN + for (WhenClause whenClause : caseWhen.getWhenClauses()) { + conditionExps.add(whenClause.getOperand()); + returnExps.add(whenClause.getResult()); + } + // ELSE + returnExps.add(caseWhen.getDefaultValue().orElse(new NullLiteral())); + } else { + // currently, only IF and CASE WHEN are supported + returnExps.add(child); + } + + // step 2: check condition expressions + for (Expression conditionExp : conditionExps) { + if (!containsAllColumn(conditionExp, ctx.keyNameToColumn.keySet())) { + return PreAggStatus.off(String.format("some columns in condition [%s] is not key.", + conditionExp.toSql())); + } + } + + // step 3: check return expressions + // NOTE: now we just support SUM, MIN, MAX and COUNT DISTINCT + int returnExprValidateNum = 0; + for (Expression returnExp : returnExps) { + // ignore cast in return expr + while (returnExp instanceof Cast) { + returnExp = returnExp.child(0); + } + // now we only check simple return expressions + String exprName = returnExp.getExpressionName(); + if (!returnExp.children().isEmpty()) { + return PreAggStatus.off(String.format("do not support compound expression [%s] in %s.", + returnExp.toSql(), matchingAggType)); + } + if (ctx.keyNameToColumn.containsKey(exprName)) { + if (matchingAggType != AggregateType.MAX && matchingAggType != AggregateType.MIN + && (aggFunc instanceof Count && !aggFunc.isDistinct())) { + return PreAggStatus.off("agg on key column should be MAX, MIN or COUNT DISTINCT."); + } + } + + if (matchingAggType == AggregateType.SUM) { + if ((ctx.valueNameToColumn.containsKey(exprName) + && ctx.valueNameToColumn.get(exprName).getAggregationType() == matchingAggType) + || returnExp.isZeroLiteral() || returnExp.isNullLiteral()) { + returnExprValidateNum++; + } else { + return PreAggStatus.off(String.format("SUM cant preagg for [%s].", aggFunc.toSql())); + } + } else if (matchingAggType == AggregateType.MAX || matchingAggType == AggregateType.MIN) { + if (ctx.keyNameToColumn.containsKey(exprName) || returnExp.isNullLiteral() + || (ctx.valueNameToColumn.containsKey(exprName) + && ctx.valueNameToColumn.get(exprName).getAggregationType() == matchingAggType)) { + returnExprValidateNum++; + } else { + return PreAggStatus.off(String.format("MAX/MIN cant preagg for [%s].", aggFunc.toSql())); + } + } else if (aggFunc.getName().equalsIgnoreCase("COUNT") && aggFunc.isDistinct()) { + if (ctx.keyNameToColumn.containsKey(exprName) + || returnExp.isZeroLiteral() || returnExp.isNullLiteral()) { + returnExprValidateNum++; + } else { + return PreAggStatus.off(String.format("COUNT DISTINCT cant preagg for [%s].", aggFunc.toSql())); + } + } + } + if (returnExprValidateNum == returnExps.size()) { + return PreAggStatus.on(); + } + return PreAggStatus.off(String.format("cant preagg for [%s].", aggFunc.toSql())); + } } private static class CheckContext { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java index 1ce12fc0efa..048cae55f3d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java @@ -260,6 +260,10 @@ public abstract class Expression extends AbstractTreeNode<Expression> implements } } + public boolean isZeroLiteral() { + return this instanceof Literal && ((Literal) this).isZero(); + } + public final Expression castTo(DataType targetType) throws AnalysisException { return uncheckedCastTo(targetType); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Literal.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Literal.java index 6ba6461922f..11c292e3a0b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Literal.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Literal.java @@ -366,4 +366,27 @@ public abstract class Literal extends Expression implements LeafExpression, Comp public boolean isStringLikeLiteral() { return dataType.isStringLikeType(); } + + /** whether is ZERO value **/ + public boolean isZero() { + if (isNullLiteral()) { + return false; + } + if (dataType.isSmallIntType() || dataType.isTinyIntType() || dataType.isIntegerType()) { + return getValue().equals(0); + } else if (dataType.isBigIntType()) { + return getValue().equals(0L); + } else if (dataType.isLargeIntType()) { + return getValue().equals(BigInteger.ZERO); + } else if (dataType.isFloatType()) { + return getValue().equals(0.0f); + } else if (dataType.isDoubleType()) { + return getValue().equals(0.0); + } else if (dataType.isDecimalV2Type()) { + return getValue().equals(BigDecimal.ZERO); + } else if (dataType.isDecimalV3Type()) { + return getValue().equals(BigDecimal.ZERO); + } + return false; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/SingleNodePlanner.java b/fe/fe-core/src/main/java/org/apache/doris/planner/SingleNodePlanner.java index 3de53090c9b..06d234c1916 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/planner/SingleNodePlanner.java +++ b/fe/fe-core/src/main/java/org/apache/doris/planner/SingleNodePlanner.java @@ -28,6 +28,7 @@ import org.apache.doris.analysis.AssertNumRowsElement; import org.apache.doris.analysis.BaseTableRef; import org.apache.doris.analysis.BinaryPredicate; import org.apache.doris.analysis.CaseExpr; +import org.apache.doris.analysis.CaseWhenClause; import org.apache.doris.analysis.CastExpr; import org.apache.doris.analysis.Expr; import org.apache.doris.analysis.ExprSubstitutionMap; @@ -77,6 +78,7 @@ import org.apache.doris.thrift.TPushAggOp; import com.google.common.base.Preconditions; import com.google.common.base.Predicate; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import com.google.common.collect.Maps; @@ -653,19 +655,38 @@ public class SingleNodePlanner { List<Column> conditionColumns = Lists.newArrayList(); if (!(aggExpr.getChild(0) instanceof SlotRef)) { Expr child = aggExpr.getChild(0); - if ((child instanceof CastExpr) && (child.getChild(0) instanceof SlotRef)) { - if (child.getType().isNumericType() - && child.getChild(0).getType().isNumericType()) { - returnColumns.add(((SlotRef) child.getChild(0)).getDesc().getColumn()); - } else { - turnOffReason = "aggExpr.getChild(0)[" + + // ignore cast + boolean castReturnExprValidate = true; + while (child instanceof CastExpr) { + if (child.getChild(0) instanceof SlotRef) { + if (child.getType().isNumericType() && child.getChild(0).getType().isNumericType()) { + returnColumns.add(((SlotRef) child.getChild(0)).getDesc().getColumn()); + } else { + turnOffReason = "aggExpr.getChild(0)[" + aggExpr.getChild(0).toSql() + "] is not Numeric CastExpr"; - aggExprValidate = false; - break; + castReturnExprValidate = false; + break; + } } - } else if (aggExpr.getChild(0) instanceof CaseExpr) { - CaseExpr caseExpr = (CaseExpr) aggExpr.getChild(0); + child = child.getChild(0); + } + if (!castReturnExprValidate) { + aggExprValidate = false; + break; + } + // convert IF to CASE WHEN. + // For example: + // IF(a > 1, 1, 0) -> CASE WHEN a > 1 THEN 1 ELSE 0 END + if (child instanceof FunctionCallExpr && ((FunctionCallExpr) child) + .getFnName().getFunction().equalsIgnoreCase("IF")) { + Preconditions.checkArgument(child.getChildren().size() == 3); + CaseWhenClause caseWhenClause = new CaseWhenClause(child.getChild(0), child.getChild(1)); + child = new CaseExpr(ImmutableList.of(caseWhenClause), child.getChild(2)); + } + if (child instanceof CaseExpr) { + CaseExpr caseExpr = (CaseExpr) child; List<Expr> conditionExprs = caseExpr.getConditionExprs(); for (Expr conditionExpr : conditionExprs) { List<TupleId> conditionTupleIds = Lists.newArrayList(); @@ -680,8 +701,14 @@ public class SingleNodePlanner { boolean caseReturnExprValidate = true; List<Expr> returnExprs = caseExpr.getReturnExprs(); for (Expr returnExpr : returnExprs) { + // ignore cast in return expr + while (returnExpr instanceof CastExpr) { + returnExpr = returnExpr.getChild(0); + } if (returnExpr instanceof SlotRef) { returnColumns.add(((SlotRef) returnExpr).getDesc().getColumn()); + } else if (returnExpr.isNullLiteral() || returnExpr.isZeroLiteral()) { + // If then expr is NULL or Zero, open the preaggregation } else { turnOffReason = "aggExpr.getChild(0)[" + aggExpr.getChild(0).toSql() + "] is not SlotExpr"; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMvIndexTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMvIndexTest.java index bc5a39cc744..83b969d9f12 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMvIndexTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMvIndexTest.java @@ -1251,4 +1251,93 @@ class SelectMvIndexTest extends BaseMaterializedIndexSelectTest implements MemoP Assertions.assertEquals(secondTableIndexName, scan1.getSelectedIndexName()); }); } + + @Test + public void testSubExpressionsInAggregation() throws Exception { + createTable("CREATE TABLE db1.`test_pre_agg_tbl` (\n" + + " `k1` int,\n" + + " `k2` int,\n" + + " `k3` char,\n" + + " `k4` int,\n" + + " `k5` bigint,\n" + + " `k6` bigint,\n" + + " `v1` int SUM,\n" + + " `v2` bigint SUM,\n" + + " `v3` bigint MAX,\n" + + " `v4` bigint MIN,\n" + + " `v5` float SUM,\n" + + " `v6` double SUM,\n" + + " `v7` decimal SUM\n" + + ") ENGINE=OLAP\n" + + "AGGREGATE KEY(`k1`, `k2`, `k3`, `k4`, `k5`, `k6`)\n" + + "COMMENT \"OLAP\"\n" + + "DISTRIBUTED BY HASH(`k1`) BUCKETS 5\n" + + "PROPERTIES (\n" + + "\"replication_num\" = \"1\"\n" + + ");"); + addRollup("alter table db1.test_pre_agg_tbl add rollup test_rollup(k1, k2, k3, v1, v2, v3, v4, v5, v6, v7)"); + + String sql1 = "select sum(case when k1 > 0 then v1 when k1 = 0 then 0 when k1 < 0 then v2 else 0 end)," + + "sum(case when k2 = 1 then 0 else v1 end)," + + "sum(case when k2 = 1 then null else v2 end)," + + "sum(case when k2 = 1 then null else v5 end)," + + "sum(case when k2 = 1 then null else v6 end)," + + "sum(case when k2 = 1 then null else v7 end)" + + "from db1.test_pre_agg_tbl"; + // legacy planner + Assertions.assertTrue(getSQLPlanOrErrorMsg(sql1).contains( + "TABLE: db1.test_pre_agg_tbl(test_rollup), PREAGGREGATION: ON")); + // nereids planner + PlanChecker.from(connectContext) + .analyze(sql1) + .rewrite() + .matches(logicalOlapScan().when(scan -> { + Assertions.assertEquals("test_rollup", scan.getSelectedMaterializedIndexName().get()); + Assertions.assertTrue(scan.getPreAggStatus().isOn()); + return true; + })); + + String sql2 = "select sum(case when k1 > 0 then v1 else 1 end) from db1.test_pre_agg_tbl"; + // legacy planner + Assertions.assertTrue(getSQLPlanOrErrorMsg(sql2).contains("PREAGGREGATION: OFF")); + // nereids planner + PlanChecker.from(connectContext) + .analyze(sql2) + .rewrite() + .matches(logicalOlapScan().when(scan -> { + Assertions.assertEquals("test_pre_agg_tbl", scan.getSelectedMaterializedIndexName().get()); + Assertions.assertTrue(scan.getPreAggStatus().isOff()); + return true; + })); + + String sql3 = "select max(case when k1 > 0 then v3 else null end),min(case when k1 > 0 then null else v4 end)" + + " from db1.test_pre_agg_tbl"; + // legacy planner + Assertions.assertTrue(getSQLPlanOrErrorMsg(sql3).contains( + "TABLE: db1.test_pre_agg_tbl(test_rollup), PREAGGREGATION: ON")); + // nereids planner + PlanChecker.from(connectContext) + .analyze(sql3) + .rewrite() + .matches(logicalOlapScan().when(scan -> { + Assertions.assertEquals("test_rollup", scan.getSelectedMaterializedIndexName().get()); + Assertions.assertTrue(scan.getPreAggStatus().isOn()); + return true; + })); + + String sql4 = "select count(distinct case when k1 > 0 then k1 else null end), " + + "count(distinct if(k2 < 0, null, k2)) from db1.test_pre_agg_tbl"; + // legacy planner + Assertions.assertTrue(getSQLPlanOrErrorMsg(sql4).contains( + "TABLE: db1.test_pre_agg_tbl(test_rollup), PREAGGREGATION: ON")); + // nereids planner + PlanChecker.from(connectContext) + .analyze(sql4) + .rewrite() + .matches(logicalOlapScan().when(scan -> { + Assertions.assertEquals("test_rollup", scan.getSelectedMaterializedIndexName().get()); + Assertions.assertTrue(scan.getPreAggStatus().isOn()); + return true; + })); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java index 9bf2e64ead0..706be618f98 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java @@ -242,7 +242,7 @@ class SelectRollupIndexTest extends BaseMaterializedIndexSelectTest implements M .matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOff()); - Assertions.assertEquals("Slot((v1 + 1)) in sum((v1 + 1)) is neither key column nor value column.", + Assertions.assertEquals("do not support compound expression [(v1 + 1)] in SUM.", preAgg.getOffReason()); return true; })); --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
