This is an automated email from the ASF dual-hosted git repository.
englefly pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 7a0734dbd60 [feature](Nereids): InferPredicates support In (#29458)
7a0734dbd60 is described below
commit 7a0734dbd60effa676d87bf5a5b7ca516e134d52
Author: jakevin <[email protected]>
AuthorDate: Fri Jan 5 21:25:30 2024 +0800
[feature](Nereids): InferPredicates support In (#29458)
---
.../nereids/rules/rewrite/InferPredicates.java | 11 +-
.../rules/rewrite/PredicatePropagation.java | 177 +++++++++++++--------
.../nereids/rules/rewrite/PullUpPredicates.java | 4 +-
.../doris/nereids/trees/expressions/EqualTo.java | 4 -
.../nereids/trees/expressions/InPredicate.java | 11 ++
.../nereids/rules/rewrite/InferPredicatesTest.java | 62 ++++----
.../rules/rewrite/PredicatePropagationTest.java | 51 ++++++
.../data/nereids_p0/hint/fix_leading.out | 2 +-
.../data/nereids_p0/hint/test_leading.out | 12 +-
.../infer_predicate/infer_predicate.groovy | 2 +-
10 files changed, 220 insertions(+), 116 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java
index 3c4593df54c..36236c3db8d 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java
@@ -27,7 +27,6 @@ import
org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;
-import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
@@ -37,6 +36,7 @@ import java.util.stream.Collectors;
/**
* infer additional predicates for `LogicalFilter` and `LogicalJoin`.
+ * <pre>
* The logic is as follows:
* 1. poll up bottom predicate then infer additional predicates
* for example:
@@ -49,9 +49,9 @@ import java.util.stream.Collectors;
* select * from (select * from t1 where t1.id = 1) t join t2 on t.id =
t2.id and t2.id = 1
* 2. put these predicates into `otherJoinConjuncts` , these predicates are
processed in the next
* round of predicate push-down
+ * </pre>
*/
public class InferPredicates extends DefaultPlanRewriter<JobContext>
implements CustomRewriter {
- private final PredicatePropagation propagation = new
PredicatePropagation();
private final PullUpPredicates pollUpPredicates = new PullUpPredicates();
@Override
@@ -62,6 +62,9 @@ public class InferPredicates extends
DefaultPlanRewriter<JobContext> implements
@Override
public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan>
join, JobContext context) {
join = visitChildren(this, join, context);
+ if (join.isMarkJoin()) {
+ return join;
+ }
Plan left = join.left();
Plan right = join.right();
Set<Expression> expressions = getAllExpressions(left, right,
join.getOnClauseCondition());
@@ -86,7 +89,7 @@ public class InferPredicates extends
DefaultPlanRewriter<JobContext> implements
break;
}
if (left != join.left() || right != join.right()) {
- return join.withChildren(ImmutableList.of(left, right));
+ return join.withChildren(left, right);
} else {
return join;
}
@@ -109,7 +112,7 @@ public class InferPredicates extends
DefaultPlanRewriter<JobContext> implements
Set<Expression> baseExpressions = pullUpPredicates(left);
baseExpressions.addAll(pullUpPredicates(right));
condition.ifPresent(on ->
baseExpressions.addAll(ExpressionUtils.extractConjunction(on)));
- baseExpressions.addAll(propagation.infer(baseExpressions));
+ baseExpressions.addAll(PredicatePropagation.infer(baseExpressions));
return baseExpressions;
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java
index 72e9023dc45..7788bbb7f06 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java
@@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.types.DataType;
@@ -55,8 +56,7 @@ public class PredicatePropagation {
INTEGRAL(IntegralType.class),
STRING(CharacterType.class),
DATE(DateLikeType.class),
- OTHER(DataType.class)
- ;
+ OTHER(DataType.class);
private final Class<? extends DataType> superClazz;
@@ -65,15 +65,15 @@ public class PredicatePropagation {
}
}
- private class ComparisonInferInfo {
+ private static class EqualInferInfo {
public final InferType inferType;
- public final Optional<Expression> left;
- public final Optional<Expression> right;
+ public final Expression left;
+ public final Expression right;
public final ComparisonPredicate comparisonPredicate;
- public ComparisonInferInfo(InferType inferType,
- Optional<Expression> left, Optional<Expression> right,
+ public EqualInferInfo(InferType inferType,
+ Expression left, Expression right,
ComparisonPredicate comparisonPredicate) {
this.inferType = inferType;
this.left = left;
@@ -85,26 +85,27 @@ public class PredicatePropagation {
/**
* infer additional predicates.
*/
- public Set<Expression> infer(Set<Expression> predicates) {
+ public static Set<Expression> infer(Set<Expression> predicates) {
Set<Expression> inferred = Sets.newHashSet();
for (Expression predicate : predicates) {
// if we support more infer predicate expression type, we should
impl withInferred() method.
// And should add inferred props in withChildren() method just
like ComparisonPredicate,
// and it's subclass, to mark the predicate is from infer.
- if (!(predicate instanceof ComparisonPredicate)) {
+ if (!(predicate instanceof ComparisonPredicate
+ || (predicate instanceof InPredicate && ((InPredicate)
predicate).isLiteralChildren()))) {
continue;
}
- ComparisonInferInfo equalInfo =
getEquivalentInferInfo((ComparisonPredicate) predicate);
+ if (predicate instanceof InPredicate) {
+ continue;
+ }
+ EqualInferInfo equalInfo = getEqualInferInfo((ComparisonPredicate)
predicate);
if (equalInfo.inferType == InferType.NONE) {
continue;
}
Set<Expression> newInferred = predicates.stream()
- .filter(ComparisonPredicate.class::isInstance)
.filter(p -> !p.equals(predicate))
- .map(ComparisonPredicate.class::cast)
- .map(this::inferInferInfo)
- .filter(predicateInfo -> predicateInfo.inferType !=
InferType.NONE)
- .map(predicateInfo -> doInfer(equalInfo, predicateInfo))
+ .filter(p -> p instanceof ComparisonPredicate || p
instanceof InPredicate)
+ .map(predicateInfo -> doInferPredicate(equalInfo,
predicateInfo))
.filter(Objects::nonNull)
.collect(Collectors.toSet());
inferred.addAll(newInferred);
@@ -113,17 +114,64 @@ public class PredicatePropagation {
return inferred;
}
+    /**
+     * Infer a new predicate from {@code predicate} using the slot equality described by {@code equalInfo}.
+     * Every occurrence of either side of the equality is replaced by the other side, e.g. with
+     * {@code a = b} and {@code a IN (1, 2)} the result is {@code b IN (1, 2)}, marked as inferred.
+     *
+     * @param equalInfo a slot = slot equality whose inferType is not NONE
+     * @param predicate a ComparisonPredicate or InPredicate other than the equality itself
+     * @return the inferred predicate, or null when the operands are not valid for inference or when
+     *         the predicate references neither side of the equality (nothing new can be inferred)
+     */
+    private static Expression doInferPredicate(EqualInferInfo equalInfo, Expression predicate) {
+        Expression equalLeft = equalInfo.left;
+        Expression equalRight = equalInfo.right;
+
+        DataType leftType = predicate.child(0).getDataType();
+        InferType inferType;
+        if (leftType instanceof CharacterType) {
+            inferType = InferType.STRING;
+        } else if (leftType instanceof IntegralType) {
+            inferType = InferType.INTEGRAL;
+        } else if (leftType instanceof DateLikeType) {
+            inferType = InferType.DATE;
+        } else {
+            inferType = InferType.OTHER;
+        }
+        if (predicate instanceof ComparisonPredicate) {
+            ComparisonPredicate comparisonPredicate = (ComparisonPredicate) predicate;
+            if (!validForInfer(comparisonPredicate.left(), inferType).isPresent()
+                    || !validForInfer(comparisonPredicate.right(), inferType).isPresent()) {
+                return null;
+            }
+        } else if (predicate instanceof InPredicate) {
+            InPredicate inPredicate = (InPredicate) predicate;
+            if (!validForInfer(inPredicate.getCompareExpr(), inferType).isPresent()) {
+                return null;
+            }
+        }
+
+        // swap the two sides of the equality simultaneously, wherever they appear in the predicate
+        Expression newPredicate = predicate.rewriteUp(e -> {
+            if (e.equals(equalLeft)) {
+                return equalRight;
+            } else if (e.equals(equalRight)) {
+                return equalLeft;
+            } else {
+                return e;
+            }
+        });
+        if (newPredicate.equals(predicate)) {
+            // nothing was substituted: re-emitting the input as an "inferred" copy adds no information
+            return null;
+        }
+        if (predicate instanceof ComparisonPredicate) {
+            return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate) newPredicate).withInferred(true);
+        } else {
+            return TypeCoercionUtils.processInPredicate((InPredicate) newPredicate).withInferred(true);
+        }
+    }
+
/**
* Use the left or right child of `leftSlotEqualToRightSlot` to replace
the left or right child of `expression`
* Now only support infer `ComparisonPredicate`.
* TODO: We should determine whether `expression` satisfies the condition
for replacement
* eg: Satisfy `expression` is non-deterministic
*/
- private Expression doInfer(ComparisonInferInfo equalInfo,
ComparisonInferInfo predicateInfo) {
- Expression predicateLeft = predicateInfo.left.get();
- Expression predicateRight = predicateInfo.right.get();
- Expression equalLeft = equalInfo.left.get();
- Expression equalRight = equalInfo.right.get();
+ private static Expression doInfer(EqualInferInfo equalInfo, EqualInferInfo
predicateInfo) {
+ Expression equalLeft = equalInfo.left;
+ Expression equalRight = equalInfo.right;
+
+ Expression predicateLeft = predicateInfo.left;
+ Expression predicateRight = predicateInfo.right;
Expression newLeft = inferOneSide(predicateLeft, equalLeft,
equalRight);
Expression newRight = inferOneSide(predicateRight, equalLeft,
equalRight);
if (newLeft == null || newRight == null) {
@@ -136,7 +184,7 @@ public class PredicatePropagation {
return DateFunctionRewrite.INSTANCE.rewrite(expr,
null).withInferred(true);
}
- private Expression inferOneSide(Expression predicateOneSide, Expression
equalLeft, Expression equalRight) {
+ private static Expression inferOneSide(Expression predicateOneSide,
Expression equalLeft, Expression equalRight) {
if (predicateOneSide instanceof SlotReference) {
if (predicateOneSide.equals(equalLeft)) {
return equalRight;
@@ -153,60 +201,55 @@ public class PredicatePropagation {
return null;
}
- private Optional<Expression> validForInfer(Expression expression,
InferType inferType) {
+ private static Optional<Expression> validForInfer(Expression expression,
InferType inferType) {
if
(!inferType.superClazz.isAssignableFrom(expression.getDataType().getClass())) {
return Optional.empty();
}
if (expression instanceof SlotReference || expression.isConstant()) {
return Optional.of(expression);
}
+ if (!(expression instanceof Cast)) {
+ return Optional.empty();
+ }
+ Cast cast = (Cast) expression;
+ Expression child = cast.child();
+ DataType dataType = cast.getDataType();
+ DataType childType = child.getDataType();
if (inferType == InferType.INTEGRAL) {
- if (expression instanceof Cast) {
- // avoid cast from wider type to narrower type, such as
cast(int as smallint)
- // IntegralType dataType = (IntegralType)
expression.getDataType();
- // DataType childType = ((Cast)
expression).child().getDataType();
- // if (childType instanceof IntegralType &&
dataType.widerThan((IntegralType) childType)) {
- // return validForInfer(((Cast) expression).child(),
inferType);
- // }
- return validForInfer(((Cast) expression).child(), inferType);
- }
+ // avoid cast from wider type to narrower type, such as cast(int
as smallint)
+ // IntegralType dataType = (IntegralType) expression.getDataType();
+ // DataType childType = ((Cast) expression).child().getDataType();
+ // if (childType instanceof IntegralType &&
dataType.widerThan((IntegralType) childType)) {
+ // return validForInfer(((Cast) expression).child(),
inferType);
+ // }
+ return validForInfer(child, inferType);
} else if (inferType == InferType.DATE) {
- if (expression instanceof Cast) {
- DataType dataType = expression.getDataType();
- DataType childType = ((Cast) expression).child().getDataType();
- // avoid lost precision
- if (dataType instanceof DateType) {
- if (childType instanceof DateV2Type || childType
instanceof DateType) {
- return validForInfer(((Cast) expression).child(),
inferType);
- }
- } else if (dataType instanceof DateV2Type) {
- if (childType instanceof DateType || childType instanceof
DateV2Type) {
- return validForInfer(((Cast) expression).child(),
inferType);
- }
- } else if (dataType instanceof DateTimeType) {
- if (!(childType instanceof DateTimeV2Type)) {
- return validForInfer(((Cast) expression).child(),
inferType);
- }
- } else if (dataType instanceof DateTimeV2Type) {
- return validForInfer(((Cast) expression).child(),
inferType);
+ // avoid lost precision
+ if (dataType instanceof DateType) {
+ if (childType instanceof DateV2Type || childType instanceof
DateType) {
+ return validForInfer(child, inferType);
+ }
+ } else if (dataType instanceof DateV2Type) {
+ if (childType instanceof DateType || childType instanceof
DateV2Type) {
+ return validForInfer(child, inferType);
}
+ } else if (dataType instanceof DateTimeType) {
+ if (!(childType instanceof DateTimeV2Type)) {
+ return validForInfer(child, inferType);
+ }
+ } else if (dataType instanceof DateTimeV2Type) {
+ return validForInfer(child, inferType);
}
} else if (inferType == InferType.STRING) {
- if (expression instanceof Cast) {
- DataType dataType = expression.getDataType();
- DataType childType = ((Cast) expression).child().getDataType();
- // avoid substring cast such as cast(char(3) as char(2))
- if (dataType.width() <= 0 || (dataType.width() >=
childType.width() && childType.width() >= 0)) {
- return validForInfer(((Cast) expression).child(),
inferType);
- }
+ // avoid substring cast such as cast(char(3) as char(2))
+ if (dataType.width() <= 0 || (dataType.width() >=
childType.width() && childType.width() >= 0)) {
+ return validForInfer(child, inferType);
}
- } else {
- return Optional.empty();
}
return Optional.empty();
}
- private ComparisonInferInfo inferInferInfo(ComparisonPredicate
comparisonPredicate) {
+ private static EqualInferInfo inferInferInfo(ComparisonPredicate
comparisonPredicate) {
DataType leftType = comparisonPredicate.left().getDataType();
InferType inferType;
if (leftType instanceof CharacterType) {
@@ -223,25 +266,27 @@ public class PredicatePropagation {
if (!left.isPresent() || !right.isPresent()) {
inferType = InferType.NONE;
}
- return new ComparisonInferInfo(inferType, left, right,
comparisonPredicate);
+ return new EqualInferInfo(inferType,
left.orElse(comparisonPredicate.left()),
+ right.orElse(comparisonPredicate.right()),
comparisonPredicate);
}
/**
* Currently only equivalence derivation is supported
* and requires that the left and right sides of an expression must be slot
+ * <p>
+ * TODO: NullSafeEqual
*/
- private ComparisonInferInfo getEquivalentInferInfo(ComparisonPredicate
predicate) {
+ private static EqualInferInfo getEqualInferInfo(ComparisonPredicate
predicate) {
if (!(predicate instanceof EqualTo)) {
- return new ComparisonInferInfo(InferType.NONE,
- Optional.of(predicate.left()),
Optional.of(predicate.right()), predicate);
+ return new EqualInferInfo(InferType.NONE, predicate.left(),
predicate.right(), predicate);
}
- ComparisonInferInfo info = inferInferInfo(predicate);
+ EqualInferInfo info = inferInferInfo(predicate);
if (info.inferType == InferType.NONE) {
return info;
}
- if (info.left.get() instanceof SlotReference && info.right.get()
instanceof SlotReference) {
+ if (info.left instanceof SlotReference && info.right instanceof
SlotReference) {
return info;
}
- return new ComparisonInferInfo(InferType.NONE, info.left, info.right,
info.comparisonPredicate);
+ return new EqualInferInfo(InferType.NONE, info.left, info.right,
info.comparisonPredicate);
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java
index 1a198c76ea5..26e1358c2e5 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java
@@ -47,7 +47,6 @@ import java.util.stream.Collectors;
*/
public class PullUpPredicates extends PlanVisitor<ImmutableSet<Expression>,
Void> {
- PredicatePropagation propagation = new PredicatePropagation();
Map<Plan, ImmutableSet<Expression>> cache = new IdentityHashMap<>();
@Override
@@ -99,6 +98,7 @@ public class PullUpPredicates extends
PlanVisitor<ImmutableSet<Expression>, Void
public ImmutableSet<Expression> visitLogicalAggregate(LogicalAggregate<?
extends Plan> aggregate, Void context) {
return cacheOrElse(aggregate, () -> {
ImmutableSet<Expression> childPredicates =
aggregate.child().accept(this, context);
+ // TODO
Map<Expression, Slot> expressionSlotMap =
aggregate.getOutputExpressions()
.stream()
.filter(this::hasAgg)
@@ -130,7 +130,7 @@ public class PullUpPredicates extends
PlanVisitor<ImmutableSet<Expression>, Void
private ImmutableSet<Expression>
getAvailableExpressions(Collection<Expression> predicates, Plan plan) {
Set<Expression> expressions = Sets.newHashSet(predicates);
- expressions.addAll(propagation.infer(expressions));
+ expressions.addAll(PredicatePropagation.infer(expressions));
return expressions.stream()
.filter(p ->
plan.getOutputSet().containsAll(p.getInputSlots()))
.collect(ImmutableSet.toImmutableSet());
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java
index 2704d446555..3e71b3b89a0 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java
@@ -39,10 +39,6 @@ public class EqualTo extends EqualPredicate implements
PropagateNullable {
super(ImmutableList.of(left, right), "=", inferred);
}
- private EqualTo(List<Expression> children) {
- this(children, false);
- }
-
private EqualTo(List<Expression> children, boolean inferred) {
super(children, "=", inferred);
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java
index d839a1e9062..c86a074dcfd 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java
@@ -48,6 +48,12 @@ public class InPredicate extends Expression {
this.options = ImmutableList.copyOf(Objects.requireNonNull(options,
"In list cannot be null"));
}
+    /**
+     * Create an IN predicate with an explicit {@code inferred} flag.
+     *
+     * @param compareExpr the expression on the left side of IN, must not be null
+     * @param options the IN list, must not be null
+     * @param inferred the inferred flag passed to {@link Expression}
+     */
+    public InPredicate(Expression compareExpr, List<Expression> options, boolean inferred) {
+        // null-check before the children builder sees the arguments; otherwise the builder fails first
+        // with a message-less NullPointerException and the messages below could never be reported
+        super(new Builder<Expression>()
+                .add(Objects.requireNonNull(compareExpr, "Compare Expr cannot be null"))
+                .addAll(Objects.requireNonNull(options, "In list cannot be null"))
+                .build(), inferred);
+        this.compareExpr = compareExpr;
+        this.options = ImmutableList.copyOf(options);
+    }
+
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitInPredicate(this, context);
}
@@ -80,6 +86,11 @@ public class InPredicate extends Expression {
});
}
+    /**
+     * Return a copy of this IN predicate carrying the requested {@code inferred} flag.
+     */
+    @Override
+    public Expression withInferred(boolean inferred) {
+        // pass the caller's flag through; hard-coding true made withInferred(false) a no-op
+        return new InPredicate(compareExpr, options, inferred);
+    }
+
@Override
public String toString() {
return compareExpr + " IN " + options.stream()
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java
index c910e98fcd5..0708ea3f172 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java
@@ -25,7 +25,7 @@ import org.apache.doris.utframe.TestWithFeService;
import org.junit.jupiter.api.Test;
-public class InferPredicatesTest extends TestWithFeService implements
MemoPatternMatchSupported {
+class InferPredicatesTest extends TestWithFeService implements
MemoPatternMatchSupported {
@Override
protected void runBeforeAll() throws Exception {
@@ -77,7 +77,7 @@ public class InferPredicatesTest extends TestWithFeService
implements MemoPatter
}
@Test
- public void inferPredicatesTest01() {
+ void inferPredicatesTest01() {
String sql = "select * from student join score on student.id =
score.sid where student.id > 1";
PlanChecker.from(connectContext)
@@ -100,7 +100,7 @@ public class InferPredicatesTest extends TestWithFeService
implements MemoPatter
}
@Test
- public void inferPredicatesTest02() {
+ void inferPredicatesTest02() {
String sql = "select * from student join score on student.id =
score.sid";
PlanChecker.from(connectContext)
@@ -117,7 +117,7 @@ public class InferPredicatesTest extends TestWithFeService
implements MemoPatter
}
@Test
- public void inferPredicatesTest03() {
+ void inferPredicatesTest03() {
String sql = "select * from student join score on student.id =
score.sid where student.id in (1,2,3)";
PlanChecker.from(connectContext)
@@ -126,18 +126,17 @@ public class InferPredicatesTest extends
TestWithFeService implements MemoPatter
.matches(
logicalProject(
logicalJoin(
- logicalFilter(
- logicalOlapScan()
- ).when(filter ->
!ExpressionUtils.isInferred(filter.getPredicate())
+ logicalFilter(logicalOlapScan()).when(filter ->
!ExpressionUtils.isInferred(filter.getPredicate())
&
filter.getPredicate().toSql().contains("id IN (1, 2, 3)")),
- logicalOlapScan()
+ logicalFilter(logicalOlapScan()).when(filter ->
ExpressionUtils.isInferred(filter.getPredicate())
+ &
filter.getPredicate().toSql().contains("sid IN (1, 2, 3)"))
)
)
);
}
@Test
- public void inferPredicatesTest04() {
+ void inferPredicatesTest04() {
String sql = "select * from student join score on student.id =
score.sid and student.id in (1,2,3)";
PlanChecker.from(connectContext)
@@ -146,18 +145,17 @@ public class InferPredicatesTest extends
TestWithFeService implements MemoPatter
.matches(
logicalProject(
logicalJoin(
- logicalFilter(
- logicalOlapScan()
- ).when(filter ->
!ExpressionUtils.isInferred(filter.getPredicate())
+ logicalFilter(logicalOlapScan()).when(filter ->
!ExpressionUtils.isInferred(filter.getPredicate())
&
filter.getPredicate().toSql().contains("id IN (1, 2, 3)")),
- logicalOlapScan()
+ logicalFilter(logicalOlapScan()).when(filter ->
ExpressionUtils.isInferred(filter.getPredicate())
+ &
filter.getPredicate().toSql().contains("sid IN (1, 2, 3)"))
)
)
);
}
@Test
- public void inferPredicatesTest05() {
+ void inferPredicatesTest05() {
String sql = "select * from student join score on student.id =
score.sid join course on score.sid = course.id where student.id > 1";
PlanChecker.from(connectContext)
@@ -185,7 +183,7 @@ public class InferPredicatesTest extends TestWithFeService
implements MemoPatter
}
@Test
- public void inferPredicatesTest06() {
+ void inferPredicatesTest06() {
String sql = "select * from student join score on student.id =
score.sid join course on score.sid = course.id and score.sid > 1";
PlanChecker.from(connectContext)
@@ -213,7 +211,7 @@ public class InferPredicatesTest extends TestWithFeService
implements MemoPatter
}
@Test
- public void inferPredicatesTest07() {
+ void inferPredicatesTest07() {
String sql = "select * from student left join score on student.id =
score.sid where student.id > 1";
PlanChecker.from(connectContext)
@@ -236,7 +234,7 @@ public class InferPredicatesTest extends TestWithFeService
implements MemoPatter
}
@Test
- public void inferPredicatesTest08() {
+ void inferPredicatesTest08() {
String sql = "select * from student left join score on student.id =
score.sid and student.id > 1";
PlanChecker.from(connectContext)
@@ -256,7 +254,7 @@ public class InferPredicatesTest extends TestWithFeService
implements MemoPatter
}
@Test
- public void inferPredicatesTest09() {
+ void inferPredicatesTest09() {
// convert left join to inner join
String sql = "select * from student left join score on student.id =
score.sid where score.sid > 1";
@@ -280,7 +278,7 @@ public class InferPredicatesTest extends TestWithFeService
implements MemoPatter
}
@Test
- public void inferPredicatesTest10() {
+ void inferPredicatesTest10() {
String sql = "select * from (select id as nid, name from student) t
left join score on t.nid = score.sid where t.nid > 1";
PlanChecker.from(connectContext)
@@ -305,7 +303,7 @@ public class InferPredicatesTest extends TestWithFeService
implements MemoPatter
}
@Test
- public void inferPredicatesTest11() {
+ void inferPredicatesTest11() {
String sql = "select * from (select id as nid, name from student) t
left join score on t.nid = score.sid and t.nid > 1";
PlanChecker.from(connectContext)
@@ -327,7 +325,7 @@ public class InferPredicatesTest extends TestWithFeService
implements MemoPatter
}
@Test
- public void inferPredicatesTest12() {
+ void inferPredicatesTest12() {
String sql = "select * from student left join (select sid as nid,
sum(grade) from score group by sid) s on s.nid = student.id where student.id >
1";
PlanChecker.from(connectContext)
@@ -356,7 +354,7 @@ public class InferPredicatesTest extends TestWithFeService
implements MemoPatter
}
@Test
- public void inferPredicatesTest13() {
+ void inferPredicatesTest13() {
String sql = "select * from (select id, name from student where id =
1) t left join score on t.id = score.sid";
PlanChecker.from(connectContext)
@@ -381,7 +379,7 @@ public class InferPredicatesTest extends TestWithFeService
implements MemoPatter
}
@Test
- public void inferPredicatesTest14() {
+ void inferPredicatesTest14() {
String sql = "select * from student left semi join score on student.id
= score.sid where student.id > 1";
PlanChecker.from(connectContext)
@@ -406,7 +404,7 @@ public class InferPredicatesTest extends TestWithFeService
implements MemoPatter
}
@Test
- public void inferPredicatesTest15() {
+ void inferPredicatesTest15() {
String sql = "select * from student left semi join score on student.id
= score.sid and student.id > 1";
PlanChecker.from(connectContext)
@@ -431,7 +429,7 @@ public class InferPredicatesTest extends TestWithFeService
implements MemoPatter
}
@Test
- public void inferPredicatesTest16() {
+ void inferPredicatesTest16() {
String sql = "select * from student left anti join score on student.id
= score.sid and student.id > 1";
PlanChecker.from(connectContext)
@@ -453,7 +451,7 @@ public class InferPredicatesTest extends TestWithFeService
implements MemoPatter
}
@Test
- public void inferPredicatesTest17() {
+ void inferPredicatesTest17() {
String sql = "select * from student left anti join score on student.id
= score.sid and score.sid > 1";
PlanChecker.from(connectContext)
@@ -475,7 +473,7 @@ public class InferPredicatesTest extends TestWithFeService
implements MemoPatter
}
@Test
- public void inferPredicatesTest18() {
+ void inferPredicatesTest18() {
String sql = "select * from student left anti join score on student.id
= score.sid where student.id > 1";
PlanChecker.from(connectContext)
@@ -500,7 +498,7 @@ public class InferPredicatesTest extends TestWithFeService
implements MemoPatter
}
@Test
- public void inferPredicatesTest19() {
+ void inferPredicatesTest19() {
String sql = "select * from subquery1\n"
+ "left semi join (\n"
+ " select t1.k3\n"
@@ -564,7 +562,7 @@ public class InferPredicatesTest extends TestWithFeService
implements MemoPatter
}
@Test
- public void inferPredicatesTest20() {
+ void inferPredicatesTest20() {
String sql = "select * from student left join score on student.id =
score.sid and score.sid > 1 inner join course on course.id = score.sid";
PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree();
PlanChecker.from(connectContext)
@@ -592,7 +590,7 @@ public class InferPredicatesTest extends TestWithFeService
implements MemoPatter
}
@Test
- public void inferPredicatesTest21() {
+ void inferPredicatesTest21() {
String sql = "select * from student,score,course where student.id =
score.sid and score.sid = course.id and score.sid > 1";
PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree();
PlanChecker.from(connectContext)
@@ -623,7 +621,7 @@ public class InferPredicatesTest extends TestWithFeService
implements MemoPatter
* test for #15310
*/
@Test
- public void inferPredicatesTest22() {
+ void inferPredicatesTest22() {
String sql = "select * from student join (select sid as id1, sid as
id2, grade from score) s on student.id = s.id1 where s.id1 > 1";
PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree();
PlanChecker.from(connectContext)
@@ -651,7 +649,7 @@ public class InferPredicatesTest extends TestWithFeService
implements MemoPatter
* in this case, filter on relation s1 should not contain s1.id = 1.
*/
@Test
- public void innerJoinShouldNotInferUnderLeftJoinOnClausePredicates() {
+ void innerJoinShouldNotInferUnderLeftJoinOnClausePredicates() {
String sql = "select * from student s1"
+ " left join (select sid as id1, sid as id2, grade from
score) s2 on s1.id = s2.id1 and s1.id = 1"
+ " join (select sid as id1, sid as id2, grade from score) s3
on s1.id = s3.id1 where s1.id = 2";
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java
new file mode 100644
index 00000000000..b1aa25df1b1
--- /dev/null
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java
@@ -0,0 +1,51 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.InPredicate;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.types.BigIntType;
+import org.apache.doris.nereids.types.SmallIntType;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.util.Set;
+
+class PredicatePropagationTest {
+ private final SlotReference a = new SlotReference("a",
SmallIntType.INSTANCE);
+ private final SlotReference b = new SlotReference("b",
BigIntType.INSTANCE);
+
+ @Test
+ void equal() {
+ Set<Expression> exprs = ImmutableSet.of(new EqualTo(a, b), new
EqualTo(a, Literal.of(1)));
+ Set<Expression> inferExprs = PredicatePropagation.infer(exprs);
+ System.out.println(inferExprs);
+ }
+
+ @Test
+ void in() {
+ Set<Expression> exprs = ImmutableSet.of(new EqualTo(a, b), new
InPredicate(a, ImmutableList.of(Literal.of(1))));
+ Set<Expression> inferExprs = PredicatePropagation.infer(exprs);
+ System.out.println(inferExprs);
+ }
+}
diff --git a/regression-test/data/nereids_p0/hint/fix_leading.out
b/regression-test/data/nereids_p0/hint/fix_leading.out
index a3ca4f54110..58122945bb6 100644
--- a/regression-test/data/nereids_p0/hint/fix_leading.out
+++ b/regression-test/data/nereids_p0/hint/fix_leading.out
@@ -9,7 +9,7 @@ PhysicalResultSink
----------PhysicalDistribute[DistributionSpecHash]
------------PhysicalOlapScan[t2]
--------PhysicalDistribute[DistributionSpecHash]
-----------NestedLoopJoin[CROSS_JOIN]
+----------NestedLoopJoin[CROSS_JOIN](t4.c4 = t3.c3)(t3.c3 = t4.c4)
------------PhysicalOlapScan[t3]
------------PhysicalDistribute[DistributionSpecReplicated]
--------------PhysicalOlapScan[t4]
diff --git a/regression-test/data/nereids_p0/hint/test_leading.out
b/regression-test/data/nereids_p0/hint/test_leading.out
index d1bd8f8bd28..fe3831a9fc4 100644
--- a/regression-test/data/nereids_p0/hint/test_leading.out
+++ b/regression-test/data/nereids_p0/hint/test_leading.out
@@ -2609,7 +2609,7 @@ PhysicalResultSink
------------PhysicalProject
--------------PhysicalOlapScan[t2]
------------PhysicalDistribute[DistributionSpecReplicated]
---------------NestedLoopJoin[CROSS_JOIN]
+--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3)
----------------PhysicalProject
------------------PhysicalOlapScan[t1]
----------------PhysicalDistribute[DistributionSpecReplicated]
@@ -2631,7 +2631,7 @@ PhysicalResultSink
------------PhysicalProject
--------------PhysicalOlapScan[t2]
------------PhysicalDistribute[DistributionSpecReplicated]
---------------NestedLoopJoin[CROSS_JOIN]
+--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3)
----------------PhysicalProject
------------------PhysicalOlapScan[t3]
----------------PhysicalDistribute[DistributionSpecReplicated]
@@ -2745,7 +2745,7 @@ PhysicalResultSink
------------PhysicalProject
--------------PhysicalOlapScan[t2]
------------PhysicalDistribute[DistributionSpecReplicated]
---------------NestedLoopJoin[CROSS_JOIN]
+--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3)
----------------PhysicalProject
------------------PhysicalOlapScan[t1]
----------------PhysicalDistribute[DistributionSpecReplicated]
@@ -2767,7 +2767,7 @@ PhysicalResultSink
------------PhysicalProject
--------------PhysicalOlapScan[t2]
------------PhysicalDistribute[DistributionSpecReplicated]
---------------NestedLoopJoin[CROSS_JOIN]
+--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3)
----------------PhysicalProject
------------------PhysicalOlapScan[t3]
----------------PhysicalDistribute[DistributionSpecReplicated]
@@ -2881,7 +2881,7 @@ PhysicalResultSink
------------PhysicalProject
--------------PhysicalOlapScan[t2]
------------PhysicalDistribute[DistributionSpecHash]
---------------NestedLoopJoin[CROSS_JOIN]
+--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3)
----------------PhysicalProject
------------------PhysicalOlapScan[t1]
----------------PhysicalDistribute[DistributionSpecReplicated]
@@ -2903,7 +2903,7 @@ PhysicalResultSink
------------PhysicalProject
--------------PhysicalOlapScan[t2]
------------PhysicalDistribute[DistributionSpecHash]
---------------NestedLoopJoin[CROSS_JOIN]
+--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3)
----------------PhysicalProject
------------------PhysicalOlapScan[t3]
----------------PhysicalDistribute[DistributionSpecReplicated]
diff --git
a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
index c5942680ea7..55645ed8ea0 100644
--- a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
+++ b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
@@ -41,7 +41,7 @@ suite("test_infer_predicate") {
explain {
sql "select * from infer_tb1 inner join infer_tb2 where
cast(infer_tb2.k4 as int) = infer_tb1.k2 and infer_tb2.k4 = 1;"
- contains "PREDICATES: k2"
+ contains "PREDICATES: CAST(k2"
}
explain {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]