This is an automated email from the ASF dual-hosted git repository.
kxiao pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-2.0 by this push:
new 47f95daaa71 [refactor](Nereids) refactor infer predicate rule to avoid
lost cast (#25637) (#25930)
47f95daaa71 is described below
commit 47f95daaa71911a4811a1fe63b508c3973b47aff
Author: morrySnow <[email protected]>
AuthorDate: Thu Oct 26 11:46:04 2023 +0800
[refactor](Nereids) refactor infer predicate rule to avoid lost cast
(#25637) (#25930)
pick from master
PR: #25637
commit id: ae66464d6b039e34771fa330ea85194849d43c43
Extract the slot and literal from each comparison predicate, then infer new
predicates from the equality predicates.
Use TypeCoercion to add casts to each newly inferred comparison predicate so
that it remains correctly typed.
This reverts "[Fix](Nereids) Add cast comparison with slot reference when
inferring predicate (#21171)"
commit 58f2593ba1b65713e7b3c1ed39fc84be8cc3ff2c.
---
.../rules/rewrite/PredicatePropagation.java | 223 ++++++++++++++++-----
.../apache/doris/nereids/util/ExpressionUtils.java | 29 ---
.../infer_predicate/infer_predicate.groovy | 2 +-
3 files changed, 170 insertions(+), 84 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java
index cc45952817a..41550bce3be 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java
@@ -17,19 +17,28 @@
package org.apache.doris.nereids.rules.rewrite;
+import org.apache.doris.nereids.parser.NereidsParser;
+import
org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
-import
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
+import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.types.DateTimeType;
+import org.apache.doris.nereids.types.DateTimeV2Type;
+import org.apache.doris.nereids.types.DateType;
+import org.apache.doris.nereids.types.DateV2Type;
+import org.apache.doris.nereids.types.coercion.CharacterType;
+import org.apache.doris.nereids.types.coercion.DateLikeType;
import org.apache.doris.nereids.types.coercion.IntegralType;
-import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.TypeCoercionUtils;
import com.google.common.collect.Sets;
-import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
@@ -40,19 +49,61 @@ import java.util.stream.Collectors;
*/
public class PredicatePropagation {
+ private enum InferType {
+ NONE(null),
+ INTEGRAL(IntegralType.class),
+ STRING(CharacterType.class),
+ DATE(DateLikeType.class),
+ OTHER(DataType.class)
+ ;
+
+ private final Class<? extends DataType> superClazz;
+
+ InferType(Class<? extends DataType> superClazz) {
+ this.superClazz = superClazz;
+ }
+ }
+
+ private class ComparisonInferInfo {
+
+ public final InferType inferType;
+ public final Optional<Expression> left;
+ public final Optional<Expression> right;
+ public final ComparisonPredicate comparisonPredicate;
+
+ public ComparisonInferInfo(InferType inferType,
+ Optional<Expression> left, Optional<Expression> right,
+ ComparisonPredicate comparisonPredicate) {
+ this.inferType = inferType;
+ this.left = left;
+ this.right = right;
+ this.comparisonPredicate = comparisonPredicate;
+ }
+ }
+
/**
* infer additional predicates.
*/
public Set<Expression> infer(Set<Expression> predicates) {
Set<Expression> inferred = Sets.newHashSet();
for (Expression predicate : predicates) {
- if (canEquivalentInfer(predicate)) {
- List<Expression> newInferred = predicates.stream()
- .filter(p -> !p.equals(predicate))
- .map(p -> doInfer(predicate, p))
- .collect(Collectors.toList());
- inferred.addAll(newInferred);
+ if (!(predicate instanceof ComparisonPredicate)) {
+ continue;
+ }
+ ComparisonInferInfo equalInfo =
getEquivalentInferInfo((ComparisonPredicate) predicate);
+ if (equalInfo.inferType == InferType.NONE) {
+ continue;
}
+ Set<Expression> newInferred = predicates.stream()
+ .filter(ComparisonPredicate.class::isInstance)
+ .filter(p -> !p.equals(predicate))
+ .map(ComparisonPredicate.class::cast)
+ .map(this::inferInferInfo)
+ .filter(predicateInfo -> predicateInfo.inferType !=
InferType.NONE)
+ .map(predicateInfo -> doInfer(equalInfo, predicateInfo))
+ .filter(Objects::nonNull)
+ .collect(Collectors.toSet());
+ inferred.addAll(newInferred);
}
inferred.removeAll(predicates);
return inferred;
@@ -64,64 +115,128 @@ public class PredicatePropagation {
* TODO: We should determine whether `expression` satisfies the condition
for replacement
* eg: Satisfy `expression` is non-deterministic
*/
- private Expression doInfer(Expression leftSlotEqualToRightSlot, Expression
expression) {
- return expression.accept(new DefaultExpressionRewriter<Void>() {
+ private Expression doInfer(ComparisonInferInfo equalInfo,
ComparisonInferInfo predicateInfo) {
+ Expression predicateLeft = predicateInfo.left.get();
+ Expression predicateRight = predicateInfo.right.get();
+ Expression equalLeft = equalInfo.left.get();
+ Expression equalRight = equalInfo.right.get();
+ Expression newLeft = inferOneSide(predicateLeft, equalLeft,
equalRight);
+ Expression newRight = inferOneSide(predicateRight, equalLeft,
equalRight);
+ if (newLeft == null || newRight == null) {
+ return null;
+ }
+ ComparisonPredicate newPredicate = (ComparisonPredicate) predicateInfo
+ .comparisonPredicate.withChildren(newLeft, newRight);
+ return SimplifyComparisonPredicate.INSTANCE
+
.rewrite(TypeCoercionUtils.processComparisonPredicate(newPredicate, newLeft,
newRight), null);
+ }
- @Override
- public Expression visit(Expression expr, Void context) {
- return expr;
+ private Expression inferOneSide(Expression predicateOneSide, Expression
equalLeft, Expression equalRight) {
+ if (predicateOneSide instanceof SlotReference) {
+ if (predicateOneSide.equals(equalLeft)) {
+ return equalRight;
+ } else if (predicateOneSide.equals(equalRight)) {
+ return equalLeft;
}
-
- @Override
- public Expression visitComparisonPredicate(ComparisonPredicate cp,
Void context) {
- // we need to get expression covered by cast, because we want
to infer different datatype
- if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.left())
&& (cp.right().isConstant())) {
- return replaceSlot(cp,
ExpressionUtils.getDatatypeCoveredByCast(cp.left()));
- } else if
(ExpressionUtils.isExpressionSlotCoveredByCast(cp.right()) &&
cp.left().isConstant()) {
- return replaceSlot(cp,
ExpressionUtils.getDatatypeCoveredByCast(cp.right()));
- }
- return super.visit(cp, context);
+ } else if (predicateOneSide.isConstant()) {
+ if (predicateOneSide instanceof IntegerLikeLiteral) {
+ return new
NereidsParser().parseExpression(((IntegerLikeLiteral)
predicateOneSide).toSql());
+ } else {
+ return predicateOneSide;
}
+ }
+ return null;
+ }
- private boolean isDataTypeValid(DataType originDataType,
Expression expr) {
- if ((leftSlotEqualToRightSlot.child(0).getDataType()
instanceof IntegralType)
- && (leftSlotEqualToRightSlot.child(1).getDataType()
instanceof IntegralType)
- && (originDataType instanceof IntegralType)) {
- // infer filter can not be lower than original datatype,
or dataset would be wrong
- if (!((IntegralType) originDataType).widerThan(
- (IntegralType)
leftSlotEqualToRightSlot.child(0).getDataType())
- && !((IntegralType)
originDataType).widerThan(
- (IntegralType)
leftSlotEqualToRightSlot.child(1).getDataType())) {
- return true;
+ private Optional<Expression> validForInfer(Expression expression,
InferType inferType) {
+ if
(!inferType.superClazz.isAssignableFrom(expression.getDataType().getClass())) {
+ return Optional.empty();
+ }
+ if (expression instanceof SlotReference || expression.isConstant()) {
+ return Optional.of(expression);
+ }
+ if (inferType == InferType.INTEGRAL) {
+ if (expression instanceof Cast) {
+ // avoid cast from wider type to narrower type, such as
cast(int as smallint)
+ // IntegralType dataType = (IntegralType)
expression.getDataType();
+ // DataType childType = ((Cast)
expression).child().getDataType();
+ // if (childType instanceof IntegralType &&
dataType.widerThan((IntegralType) childType)) {
+ // return validForInfer(((Cast) expression).child(),
inferType);
+ // }
+ return validForInfer(((Cast) expression).child(), inferType);
+ }
+ } else if (inferType == InferType.DATE) {
+ if (expression instanceof Cast) {
+ DataType dataType = expression.getDataType();
+ DataType childType = ((Cast) expression).child().getDataType();
+ // avoid lost precision
+ if (dataType instanceof DateType) {
+ if (childType instanceof DateV2Type || childType
instanceof DateType) {
+ return validForInfer(((Cast) expression).child(),
inferType);
+ }
+ } else if (dataType instanceof DateV2Type) {
+ if (childType instanceof DateType || childType instanceof
DateV2Type) {
+ return validForInfer(((Cast) expression).child(),
inferType);
+ }
+ } else if (dataType instanceof DateTimeType) {
+ if (!(childType instanceof DateTimeV2Type)) {
+ return validForInfer(((Cast) expression).child(),
inferType);
}
+ } else if (dataType instanceof DateTimeV2Type) {
+ return validForInfer(((Cast) expression).child(),
inferType);
}
- return false;
}
-
- private Expression replaceSlot(Expression expr, DataType
originDataType) {
- return expr.rewriteUp(e -> {
- if (isDataTypeValid(originDataType,
leftSlotEqualToRightSlot)) {
- if (ExpressionUtils.isTwoExpressionEqualWithCast(e,
leftSlotEqualToRightSlot.child(0))) {
- return leftSlotEqualToRightSlot.child(1);
- } else if
(ExpressionUtils.isTwoExpressionEqualWithCast(e,
leftSlotEqualToRightSlot.child(1))) {
- return leftSlotEqualToRightSlot.child(0);
- }
- }
- return e;
- });
+ } else if (inferType == InferType.STRING) {
+ if (expression instanceof Cast) {
+ DataType dataType = expression.getDataType();
+ DataType childType = ((Cast) expression).child().getDataType();
+ // avoid substring cast such as cast(char(3) as char(2))
+ if (dataType.width() <= 0 || (dataType.width() >=
childType.width() && childType.width() >= 0)) {
+ return validForInfer(((Cast) expression).child(),
inferType);
+ }
}
- }, null);
+ } else {
+ return Optional.empty();
+ }
+ return Optional.empty();
+ }
+
+ private ComparisonInferInfo inferInferInfo(ComparisonPredicate
comparisonPredicate) {
+ DataType leftType = comparisonPredicate.left().getDataType();
+ InferType inferType;
+ if (leftType instanceof CharacterType) {
+ inferType = InferType.STRING;
+ } else if (leftType instanceof IntegralType) {
+ inferType = InferType.INTEGRAL;
+ } else if (leftType instanceof DateLikeType) {
+ inferType = InferType.DATE;
+ } else {
+ inferType = InferType.OTHER;
+ }
+ Optional<Expression> left = validForInfer(comparisonPredicate.left(),
inferType);
+ Optional<Expression> right =
validForInfer(comparisonPredicate.right(), inferType);
+ if (!left.isPresent() || !right.isPresent()) {
+ inferType = InferType.NONE;
+ }
+ return new ComparisonInferInfo(inferType, left, right,
comparisonPredicate);
}
/**
* Currently only equivalence derivation is supported
* and requires that the left and right sides of an expression must be slot
*/
- private boolean canEquivalentInfer(Expression predicate) {
- return predicate instanceof EqualTo
- && predicate.children().stream().allMatch(e ->
- (e instanceof SlotReference) || (e instanceof Cast &&
e.child(0).isSlot()))
- &&
predicate.child(0).getDataType().equals(predicate.child(1).getDataType());
+ private ComparisonInferInfo getEquivalentInferInfo(ComparisonPredicate
predicate) {
+ if (!(predicate instanceof EqualTo)) {
+ return new ComparisonInferInfo(InferType.NONE,
+ Optional.of(predicate.left()),
Optional.of(predicate.right()), predicate);
+ }
+ ComparisonInferInfo info = inferInferInfo(predicate);
+ if (info.inferType == InferType.NONE) {
+ return info;
+ }
+ if (info.left.get() instanceof SlotReference && info.right.get()
instanceof SlotReference) {
+ return info;
+ }
+ return new ComparisonInferInfo(InferType.NONE, info.left, info.right,
info.comparisonPredicate);
}
-
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
index deb6eb983dd..8ddbd97d62b 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
@@ -39,7 +39,6 @@ import
org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
-import org.apache.doris.nereids.types.DataType;
import com.google.common.base.Preconditions;
import com.google.common.base.Predicate;
@@ -253,34 +252,6 @@ public class ExpressionUtils {
}
}
- /**
- * get slot covered by cast
- * example: input: cast(cast(table.columnA)) output: columnA.datatype
- *
- */
- public static DataType getDatatypeCoveredByCast(Expression expr) {
- if (expr instanceof Cast) {
- return getDatatypeCoveredByCast(((Cast) expr).child());
- }
- return expr.getDataType();
- }
-
- /**
- * judge if expression is slot covered by cast
- * example: cast(cast(table.columnA))
- */
- public static boolean isExpressionSlotCoveredByCast(Expression expr) {
- if (expr instanceof Cast) {
- return isExpressionSlotCoveredByCast(((Cast) expr).child());
- }
- return expr instanceof SlotReference;
- }
-
- public static boolean isTwoExpressionEqualWithCast(Expression left,
Expression right) {
- return ExpressionUtils.extractSlotOrCastOnSlot(left)
- .equals(ExpressionUtils.extractSlotOrCastOnSlot(right));
- }
-
/**
* Replace expression node in the expression tree by `replaceMap` in
top-down manner.
* For example.
diff --git
a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
index a1621f1c239..c5942680ea7 100644
--- a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
+++ b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
@@ -41,7 +41,7 @@ suite("test_infer_predicate") {
explain {
sql "select * from infer_tb1 inner join infer_tb2 where
cast(infer_tb2.k4 as int) = infer_tb1.k2 and infer_tb2.k4 = 1;"
- notContains "PREDICATES: k2"
+ contains "PREDICATES: k2"
}
explain {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]