This is an automated email from the ASF dual-hosted git repository.

englefly pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 7a0734dbd60 [feature](Nereids): InferPredicates support In (#29458)
7a0734dbd60 is described below

commit 7a0734dbd60effa676d87bf5a5b7ca516e134d52
Author: jakevin <[email protected]>
AuthorDate: Fri Jan 5 21:25:30 2024 +0800

    [feature](Nereids): InferPredicates support In (#29458)
---
 .../nereids/rules/rewrite/InferPredicates.java     |  11 +-
 .../rules/rewrite/PredicatePropagation.java        | 177 +++++++++++++--------
 .../nereids/rules/rewrite/PullUpPredicates.java    |   4 +-
 .../doris/nereids/trees/expressions/EqualTo.java   |   4 -
 .../nereids/trees/expressions/InPredicate.java     |  11 ++
 .../nereids/rules/rewrite/InferPredicatesTest.java |  62 ++++----
 .../rules/rewrite/PredicatePropagationTest.java    |  51 ++++++
 .../data/nereids_p0/hint/fix_leading.out           |   2 +-
 .../data/nereids_p0/hint/test_leading.out          |  12 +-
 .../infer_predicate/infer_predicate.groovy         |   2 +-
 10 files changed, 220 insertions(+), 116 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java
index 3c4593df54c..36236c3db8d 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java
@@ -27,7 +27,6 @@ import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
 import org.apache.doris.nereids.util.ExpressionUtils;
 import org.apache.doris.nereids.util.PlanUtils;
 
-import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Sets;
 
@@ -37,6 +36,7 @@ import java.util.stream.Collectors;
 
 /**
  * infer additional predicates for `LogicalFilter` and `LogicalJoin`.
+ * <pre>
  * The logic is as follows:
  * 1. poll up bottom predicate then infer additional predicates
  *   for example:
@@ -49,9 +49,9 @@ import java.util.stream.Collectors;
  *      select * from (select * from t1 where t1.id = 1) t join t2 on t.id = 
t2.id and t2.id = 1
  * 2. put these predicates into `otherJoinConjuncts` , these predicates are 
processed in the next
  *   round of predicate push-down
+ * </pre>
  */
 public class InferPredicates extends DefaultPlanRewriter<JobContext> 
implements CustomRewriter {
-    private final PredicatePropagation propagation = new 
PredicatePropagation();
     private final PullUpPredicates pollUpPredicates = new PullUpPredicates();
 
     @Override
@@ -62,6 +62,9 @@ public class InferPredicates extends DefaultPlanRewriter<JobContext> implements
     @Override
     public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> 
join, JobContext context) {
         join = visitChildren(this, join, context);
+        if (join.isMarkJoin()) {
+            return join;
+        }
         Plan left = join.left();
         Plan right = join.right();
         Set<Expression> expressions = getAllExpressions(left, right, 
join.getOnClauseCondition());
@@ -86,7 +89,7 @@ public class InferPredicates extends DefaultPlanRewriter<JobContext> implements
                 break;
         }
         if (left != join.left() || right != join.right()) {
-            return join.withChildren(ImmutableList.of(left, right));
+            return join.withChildren(left, right);
         } else {
             return join;
         }
@@ -109,7 +112,7 @@ public class InferPredicates extends DefaultPlanRewriter<JobContext> implements
         Set<Expression> baseExpressions = pullUpPredicates(left);
         baseExpressions.addAll(pullUpPredicates(right));
         condition.ifPresent(on -> 
baseExpressions.addAll(ExpressionUtils.extractConjunction(on)));
-        baseExpressions.addAll(propagation.infer(baseExpressions));
+        baseExpressions.addAll(PredicatePropagation.infer(baseExpressions));
         return baseExpressions;
     }
 
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java
index 72e9023dc45..7788bbb7f06 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java
@@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.expressions.Cast;
 import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
 import org.apache.doris.nereids.trees.expressions.EqualTo;
 import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.InPredicate;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
 import org.apache.doris.nereids.types.DataType;
@@ -55,8 +56,7 @@ public class PredicatePropagation {
         INTEGRAL(IntegralType.class),
         STRING(CharacterType.class),
         DATE(DateLikeType.class),
-        OTHER(DataType.class)
-        ;
+        OTHER(DataType.class);
 
         private final Class<? extends DataType> superClazz;
 
@@ -65,15 +65,15 @@ public class PredicatePropagation {
         }
     }
 
-    private class ComparisonInferInfo {
+    private static class EqualInferInfo {
 
         public final InferType inferType;
-        public final Optional<Expression> left;
-        public final Optional<Expression> right;
+        public final Expression left;
+        public final Expression right;
         public final ComparisonPredicate comparisonPredicate;
 
-        public ComparisonInferInfo(InferType inferType,
-                Optional<Expression> left, Optional<Expression> right,
+        public EqualInferInfo(InferType inferType,
+                Expression left, Expression right,
                 ComparisonPredicate comparisonPredicate) {
             this.inferType = inferType;
             this.left = left;
@@ -85,26 +85,27 @@ public class PredicatePropagation {
     /**
      * infer additional predicates.
      */
-    public Set<Expression> infer(Set<Expression> predicates) {
+    public static Set<Expression> infer(Set<Expression> predicates) {
         Set<Expression> inferred = Sets.newHashSet();
         for (Expression predicate : predicates) {
             // if we support more infer predicate expression type, we should 
impl withInferred() method.
             // And should add inferred props in withChildren() method just 
like ComparisonPredicate,
             // and it's subclass, to mark the predicate is from infer.
-            if (!(predicate instanceof ComparisonPredicate)) {
+            if (!(predicate instanceof ComparisonPredicate
+                    || (predicate instanceof InPredicate && ((InPredicate) 
predicate).isLiteralChildren()))) {
                 continue;
             }
-            ComparisonInferInfo equalInfo = 
getEquivalentInferInfo((ComparisonPredicate) predicate);
+            if (predicate instanceof InPredicate) {
+                continue;
+            }
+            EqualInferInfo equalInfo = getEqualInferInfo((ComparisonPredicate) 
predicate);
             if (equalInfo.inferType == InferType.NONE) {
                 continue;
             }
             Set<Expression> newInferred = predicates.stream()
-                    .filter(ComparisonPredicate.class::isInstance)
                     .filter(p -> !p.equals(predicate))
-                    .map(ComparisonPredicate.class::cast)
-                    .map(this::inferInferInfo)
-                    .filter(predicateInfo -> predicateInfo.inferType != 
InferType.NONE)
-                    .map(predicateInfo -> doInfer(equalInfo, predicateInfo))
+                    .filter(p -> p instanceof ComparisonPredicate || p 
instanceof InPredicate)
+                    .map(predicateInfo -> doInferPredicate(equalInfo, 
predicateInfo))
                     .filter(Objects::nonNull)
                     .collect(Collectors.toSet());
             inferred.addAll(newInferred);
@@ -113,17 +114,64 @@ public class PredicatePropagation {
         return inferred;
     }
 
+    private static Expression doInferPredicate(EqualInferInfo equalInfo, 
Expression predicate) {
+        Expression equalLeft = equalInfo.left;
+        Expression equalRight = equalInfo.right;
+
+        DataType leftType = predicate.child(0).getDataType();
+        InferType inferType;
+        if (leftType instanceof CharacterType) {
+            inferType = InferType.STRING;
+        } else if (leftType instanceof IntegralType) {
+            inferType = InferType.INTEGRAL;
+        } else if (leftType instanceof DateLikeType) {
+            inferType = InferType.DATE;
+        } else {
+            inferType = InferType.OTHER;
+        }
+        if (predicate instanceof ComparisonPredicate) {
+            ComparisonPredicate comparisonPredicate = (ComparisonPredicate) 
predicate;
+            Optional<Expression> left = 
validForInfer(comparisonPredicate.left(), inferType);
+            Optional<Expression> right = 
validForInfer(comparisonPredicate.right(), inferType);
+            if (!left.isPresent() || !right.isPresent()) {
+                return null;
+            }
+        } else if (predicate instanceof InPredicate) {
+            InPredicate inPredicate = (InPredicate) predicate;
+            Optional<Expression> left = 
validForInfer(inPredicate.getCompareExpr(), inferType);
+            if (!left.isPresent()) {
+                return null;
+            }
+        }
+
+        Expression newPredicate = predicate.rewriteUp(e -> {
+            if (e.equals(equalLeft)) {
+                return equalRight;
+            } else if (e.equals(equalRight)) {
+                return equalLeft;
+            } else {
+                return e;
+            }
+        });
+        if (predicate instanceof ComparisonPredicate) {
+            return 
TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate) 
newPredicate).withInferred(true);
+        } else {
+            return TypeCoercionUtils.processInPredicate((InPredicate) 
newPredicate).withInferred(true);
+        }
+    }
+
     /**
      * Use the left or right child of `leftSlotEqualToRightSlot` to replace 
the left or right child of `expression`
      * Now only support infer `ComparisonPredicate`.
      * TODO: We should determine whether `expression` satisfies the condition 
for replacement
      *       eg: Satisfy `expression` is non-deterministic
      */
-    private Expression doInfer(ComparisonInferInfo equalInfo, 
ComparisonInferInfo predicateInfo) {
-        Expression predicateLeft = predicateInfo.left.get();
-        Expression predicateRight = predicateInfo.right.get();
-        Expression equalLeft = equalInfo.left.get();
-        Expression equalRight = equalInfo.right.get();
+    private static Expression doInfer(EqualInferInfo equalInfo, EqualInferInfo 
predicateInfo) {
+        Expression equalLeft = equalInfo.left;
+        Expression equalRight = equalInfo.right;
+
+        Expression predicateLeft = predicateInfo.left;
+        Expression predicateRight = predicateInfo.right;
         Expression newLeft = inferOneSide(predicateLeft, equalLeft, 
equalRight);
         Expression newRight = inferOneSide(predicateRight, equalLeft, 
equalRight);
         if (newLeft == null || newRight == null) {
@@ -136,7 +184,7 @@ public class PredicatePropagation {
         return DateFunctionRewrite.INSTANCE.rewrite(expr, 
null).withInferred(true);
     }
 
-    private Expression inferOneSide(Expression predicateOneSide, Expression 
equalLeft, Expression equalRight) {
+    private static Expression inferOneSide(Expression predicateOneSide, 
Expression equalLeft, Expression equalRight) {
         if (predicateOneSide instanceof SlotReference) {
             if (predicateOneSide.equals(equalLeft)) {
                 return equalRight;
@@ -153,60 +201,55 @@ public class PredicatePropagation {
         return null;
     }
 
-    private Optional<Expression> validForInfer(Expression expression, 
InferType inferType) {
+    private static Optional<Expression> validForInfer(Expression expression, 
InferType inferType) {
         if 
(!inferType.superClazz.isAssignableFrom(expression.getDataType().getClass())) {
             return Optional.empty();
         }
         if (expression instanceof SlotReference || expression.isConstant()) {
             return Optional.of(expression);
         }
+        if (!(expression instanceof Cast)) {
+            return Optional.empty();
+        }
+        Cast cast = (Cast) expression;
+        Expression child = cast.child();
+        DataType dataType = cast.getDataType();
+        DataType childType = child.getDataType();
         if (inferType == InferType.INTEGRAL) {
-            if (expression instanceof Cast) {
-                // avoid cast from wider type to narrower type, such as 
cast(int as smallint)
-                // IntegralType dataType = (IntegralType) 
expression.getDataType();
-                // DataType childType = ((Cast) 
expression).child().getDataType();
-                // if (childType instanceof IntegralType && 
dataType.widerThan((IntegralType) childType)) {
-                //     return validForInfer(((Cast) expression).child(), 
inferType);
-                // }
-                return validForInfer(((Cast) expression).child(), inferType);
-            }
+            // avoid cast from wider type to narrower type, such as cast(int 
as smallint)
+            // IntegralType dataType = (IntegralType) expression.getDataType();
+            // DataType childType = ((Cast) expression).child().getDataType();
+            // if (childType instanceof IntegralType && 
dataType.widerThan((IntegralType) childType)) {
+            //     return validForInfer(((Cast) expression).child(), 
inferType);
+            // }
+            return validForInfer(child, inferType);
         } else if (inferType == InferType.DATE) {
-            if (expression instanceof Cast) {
-                DataType dataType = expression.getDataType();
-                DataType childType = ((Cast) expression).child().getDataType();
-                // avoid lost precision
-                if (dataType instanceof DateType) {
-                    if (childType instanceof DateV2Type || childType 
instanceof DateType) {
-                        return validForInfer(((Cast) expression).child(), 
inferType);
-                    }
-                } else if (dataType instanceof DateV2Type) {
-                    if (childType instanceof DateType || childType instanceof 
DateV2Type) {
-                        return validForInfer(((Cast) expression).child(), 
inferType);
-                    }
-                } else if (dataType instanceof DateTimeType) {
-                    if (!(childType instanceof DateTimeV2Type)) {
-                        return validForInfer(((Cast) expression).child(), 
inferType);
-                    }
-                } else if (dataType instanceof DateTimeV2Type) {
-                    return validForInfer(((Cast) expression).child(), 
inferType);
+            // avoid lost precision
+            if (dataType instanceof DateType) {
+                if (childType instanceof DateV2Type || childType instanceof 
DateType) {
+                    return validForInfer(child, inferType);
+                }
+            } else if (dataType instanceof DateV2Type) {
+                if (childType instanceof DateType || childType instanceof 
DateV2Type) {
+                    return validForInfer(child, inferType);
                 }
+            } else if (dataType instanceof DateTimeType) {
+                if (!(childType instanceof DateTimeV2Type)) {
+                    return validForInfer(child, inferType);
+                }
+            } else if (dataType instanceof DateTimeV2Type) {
+                return validForInfer(child, inferType);
             }
         } else if (inferType == InferType.STRING) {
-            if (expression instanceof Cast) {
-                DataType dataType = expression.getDataType();
-                DataType childType = ((Cast) expression).child().getDataType();
-                // avoid substring cast such as cast(char(3) as char(2))
-                if (dataType.width() <= 0 || (dataType.width() >= 
childType.width() && childType.width() >= 0)) {
-                    return validForInfer(((Cast) expression).child(), 
inferType);
-                }
+            // avoid substring cast such as cast(char(3) as char(2))
+            if (dataType.width() <= 0 || (dataType.width() >= 
childType.width() && childType.width() >= 0)) {
+                return validForInfer(child, inferType);
             }
-        } else {
-            return Optional.empty();
         }
         return Optional.empty();
     }
 
-    private ComparisonInferInfo inferInferInfo(ComparisonPredicate 
comparisonPredicate) {
+    private static EqualInferInfo inferInferInfo(ComparisonPredicate 
comparisonPredicate) {
         DataType leftType = comparisonPredicate.left().getDataType();
         InferType inferType;
         if (leftType instanceof CharacterType) {
@@ -223,25 +266,27 @@ public class PredicatePropagation {
         if (!left.isPresent() || !right.isPresent()) {
             inferType = InferType.NONE;
         }
-        return new ComparisonInferInfo(inferType, left, right, 
comparisonPredicate);
+        return new EqualInferInfo(inferType, 
left.orElse(comparisonPredicate.left()),
+                right.orElse(comparisonPredicate.right()), 
comparisonPredicate);
     }
 
     /**
      * Currently only equivalence derivation is supported
      * and requires that the left and right sides of an expression must be slot
+     * <p>
+     * TODO: NullSafeEqual
      */
-    private ComparisonInferInfo getEquivalentInferInfo(ComparisonPredicate 
predicate) {
+    private static EqualInferInfo getEqualInferInfo(ComparisonPredicate 
predicate) {
         if (!(predicate instanceof EqualTo)) {
-            return new ComparisonInferInfo(InferType.NONE,
-                    Optional.of(predicate.left()), 
Optional.of(predicate.right()), predicate);
+            return new EqualInferInfo(InferType.NONE, predicate.left(), 
predicate.right(), predicate);
         }
-        ComparisonInferInfo info = inferInferInfo(predicate);
+        EqualInferInfo info = inferInferInfo(predicate);
         if (info.inferType == InferType.NONE) {
             return info;
         }
-        if (info.left.get() instanceof SlotReference && info.right.get() 
instanceof SlotReference) {
+        if (info.left instanceof SlotReference && info.right instanceof 
SlotReference) {
             return info;
         }
-        return new ComparisonInferInfo(InferType.NONE, info.left, info.right, 
info.comparisonPredicate);
+        return new EqualInferInfo(InferType.NONE, info.left, info.right, 
info.comparisonPredicate);
     }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java
index 1a198c76ea5..26e1358c2e5 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java
@@ -47,7 +47,6 @@ import java.util.stream.Collectors;
  */
 public class PullUpPredicates extends PlanVisitor<ImmutableSet<Expression>, 
Void> {
 
-    PredicatePropagation propagation = new PredicatePropagation();
     Map<Plan, ImmutableSet<Expression>> cache = new IdentityHashMap<>();
 
     @Override
@@ -99,6 +98,7 @@ public class PullUpPredicates extends PlanVisitor<ImmutableSet<Expression>, Void
     public ImmutableSet<Expression> visitLogicalAggregate(LogicalAggregate<? 
extends Plan> aggregate, Void context) {
         return cacheOrElse(aggregate, () -> {
             ImmutableSet<Expression> childPredicates = 
aggregate.child().accept(this, context);
+            // TODO
             Map<Expression, Slot> expressionSlotMap = 
aggregate.getOutputExpressions()
                     .stream()
                     .filter(this::hasAgg)
@@ -130,7 +130,7 @@ public class PullUpPredicates extends PlanVisitor<ImmutableSet<Expression>, Void
 
     private ImmutableSet<Expression> 
getAvailableExpressions(Collection<Expression> predicates, Plan plan) {
         Set<Expression> expressions = Sets.newHashSet(predicates);
-        expressions.addAll(propagation.infer(expressions));
+        expressions.addAll(PredicatePropagation.infer(expressions));
         return expressions.stream()
                 .filter(p -> 
plan.getOutputSet().containsAll(p.getInputSlots()))
                 .collect(ImmutableSet.toImmutableSet());
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java
index 2704d446555..3e71b3b89a0 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java
@@ -39,10 +39,6 @@ public class EqualTo extends EqualPredicate implements PropagateNullable {
         super(ImmutableList.of(left, right), "=", inferred);
     }
 
-    private EqualTo(List<Expression> children) {
-        this(children, false);
-    }
-
     private EqualTo(List<Expression> children, boolean inferred) {
         super(children, "=", inferred);
     }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java
index d839a1e9062..c86a074dcfd 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java
@@ -48,6 +48,12 @@ public class InPredicate extends Expression {
         this.options = ImmutableList.copyOf(Objects.requireNonNull(options, 
"In list cannot be null"));
     }
 
+    public InPredicate(Expression compareExpr, List<Expression> options, 
boolean inferred) {
+        super(new 
Builder<Expression>().add(compareExpr).addAll(options).build(), inferred);
+        this.compareExpr = Objects.requireNonNull(compareExpr, "Compare Expr 
cannot be null");
+        this.options = ImmutableList.copyOf(Objects.requireNonNull(options, 
"In list cannot be null"));
+    }
+
     public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
         return visitor.visitInPredicate(this, context);
     }
@@ -80,6 +86,11 @@ public class InPredicate extends Expression {
         });
     }
 
+    @Override
+    public Expression withInferred(boolean inferred) {
+        return new InPredicate(children.get(0), 
ImmutableList.copyOf(children).subList(1, children.size()), true);
+    }
+
     @Override
     public String toString() {
         return compareExpr + " IN " + options.stream()
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java
index c910e98fcd5..0708ea3f172 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java
@@ -25,7 +25,7 @@ import org.apache.doris.utframe.TestWithFeService;
 
 import org.junit.jupiter.api.Test;
 
-public class InferPredicatesTest extends TestWithFeService implements 
MemoPatternMatchSupported {
+class InferPredicatesTest extends TestWithFeService implements 
MemoPatternMatchSupported {
 
     @Override
     protected void runBeforeAll() throws Exception {
@@ -77,7 +77,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest01() {
+    void inferPredicatesTest01() {
         String sql = "select * from student join score on student.id = 
score.sid where student.id > 1";
 
         PlanChecker.from(connectContext)
@@ -100,7 +100,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest02() {
+    void inferPredicatesTest02() {
         String sql = "select * from student join score on student.id = 
score.sid";
 
         PlanChecker.from(connectContext)
@@ -117,7 +117,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest03() {
+    void inferPredicatesTest03() {
         String sql = "select * from student join score on student.id = 
score.sid where student.id in (1,2,3)";
 
         PlanChecker.from(connectContext)
@@ -126,18 +126,17 @@ public class InferPredicatesTest extends 
TestWithFeService implements MemoPatter
                 .matches(
                     logicalProject(
                         logicalJoin(
-                            logicalFilter(
-                                    logicalOlapScan()
-                            ).when(filter -> 
!ExpressionUtils.isInferred(filter.getPredicate())
+                            logicalFilter(logicalOlapScan()).when(filter -> 
!ExpressionUtils.isInferred(filter.getPredicate())
                                     & 
filter.getPredicate().toSql().contains("id IN (1, 2, 3)")),
-                            logicalOlapScan()
+                            logicalFilter(logicalOlapScan()).when(filter -> 
ExpressionUtils.isInferred(filter.getPredicate())
+                                    & 
filter.getPredicate().toSql().contains("sid IN (1, 2, 3)"))
                         )
                     )
                 );
     }
 
     @Test
-    public void inferPredicatesTest04() {
+    void inferPredicatesTest04() {
         String sql = "select * from student join score on student.id = 
score.sid and student.id in (1,2,3)";
 
         PlanChecker.from(connectContext)
@@ -146,18 +145,17 @@ public class InferPredicatesTest extends 
TestWithFeService implements MemoPatter
                 .matches(
                     logicalProject(
                         logicalJoin(
-                            logicalFilter(
-                                    logicalOlapScan()
-                            ).when(filter -> 
!ExpressionUtils.isInferred(filter.getPredicate())
+                            logicalFilter(logicalOlapScan()).when(filter -> 
!ExpressionUtils.isInferred(filter.getPredicate())
                                     & 
filter.getPredicate().toSql().contains("id IN (1, 2, 3)")),
-                            logicalOlapScan()
+                            logicalFilter(logicalOlapScan()).when(filter -> 
ExpressionUtils.isInferred(filter.getPredicate())
+                                    & 
filter.getPredicate().toSql().contains("sid IN (1, 2, 3)"))
                         )
                     )
                 );
     }
 
     @Test
-    public void inferPredicatesTest05() {
+    void inferPredicatesTest05() {
         String sql = "select * from student join score on student.id = 
score.sid join course on score.sid = course.id where student.id > 1";
 
         PlanChecker.from(connectContext)
@@ -185,7 +183,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest06() {
+    void inferPredicatesTest06() {
         String sql = "select * from student join score on student.id = 
score.sid join course on score.sid = course.id and score.sid > 1";
 
         PlanChecker.from(connectContext)
@@ -213,7 +211,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest07() {
+    void inferPredicatesTest07() {
         String sql = "select * from student left join score on student.id = 
score.sid where student.id > 1";
 
         PlanChecker.from(connectContext)
@@ -236,7 +234,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest08() {
+    void inferPredicatesTest08() {
         String sql = "select * from student left join score on student.id = 
score.sid and student.id > 1";
 
         PlanChecker.from(connectContext)
@@ -256,7 +254,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest09() {
+    void inferPredicatesTest09() {
         // convert left join to inner join
         String sql = "select * from student left join score on student.id = 
score.sid where score.sid > 1";
 
@@ -280,7 +278,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest10() {
+    void inferPredicatesTest10() {
         String sql = "select * from (select id as nid, name from student) t 
left join score on t.nid = score.sid where t.nid > 1";
 
         PlanChecker.from(connectContext)
@@ -305,7 +303,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest11() {
+    void inferPredicatesTest11() {
         String sql = "select * from (select id as nid, name from student) t 
left join score on t.nid = score.sid and t.nid > 1";
 
         PlanChecker.from(connectContext)
@@ -327,7 +325,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest12() {
+    void inferPredicatesTest12() {
         String sql = "select * from student left join (select sid as nid, 
sum(grade) from score group by sid) s on s.nid = student.id where student.id > 
1";
 
         PlanChecker.from(connectContext)
@@ -356,7 +354,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest13() {
+    void inferPredicatesTest13() {
         String sql = "select * from (select id, name from student where id = 
1) t left join score on t.id = score.sid";
 
         PlanChecker.from(connectContext)
@@ -381,7 +379,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest14() {
+    void inferPredicatesTest14() {
         String sql = "select * from student left semi join score on student.id 
= score.sid where student.id > 1";
 
         PlanChecker.from(connectContext)
@@ -406,7 +404,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest15() {
+    void inferPredicatesTest15() {
         String sql = "select * from student left semi join score on student.id 
= score.sid and student.id > 1";
 
         PlanChecker.from(connectContext)
@@ -431,7 +429,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest16() {
+    void inferPredicatesTest16() {
         String sql = "select * from student left anti join score on student.id 
= score.sid and student.id > 1";
 
         PlanChecker.from(connectContext)
@@ -453,7 +451,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest17() {
+    void inferPredicatesTest17() {
         String sql = "select * from student left anti join score on student.id 
= score.sid and score.sid > 1";
 
         PlanChecker.from(connectContext)
@@ -475,7 +473,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest18() {
+    void inferPredicatesTest18() {
         String sql = "select * from student left anti join score on student.id = score.sid where student.id > 1";
 
         PlanChecker.from(connectContext)
@@ -500,7 +498,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest19() {
+    void inferPredicatesTest19() {
         String sql = "select * from subquery1\n"
                 + "left semi join (\n"
                 + "  select t1.k3\n"
@@ -564,7 +562,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest20() {
+    void inferPredicatesTest20() {
         String sql = "select * from student left join score on student.id = score.sid and score.sid > 1 inner join course on course.id = score.sid";
         PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree();
         PlanChecker.from(connectContext)
@@ -592,7 +590,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest21() {
+    void inferPredicatesTest21() {
         String sql = "select * from student,score,course where student.id = score.sid and score.sid = course.id and score.sid > 1";
         PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree();
         PlanChecker.from(connectContext)
@@ -623,7 +621,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
      * test for #15310
      */
     @Test
-    public void inferPredicatesTest22() {
+    void inferPredicatesTest22() {
         String sql = "select * from student join (select sid as id1, sid as id2, grade from score) s on student.id = s.id1 where s.id1 > 1";
         PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree();
         PlanChecker.from(connectContext)
@@ -651,7 +649,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
      * in this case, filter on relation s1 should not contain s1.id = 1.
      */
     @Test
-    public void innerJoinShouldNotInferUnderLeftJoinOnClausePredicates() {
+    void innerJoinShouldNotInferUnderLeftJoinOnClausePredicates() {
         String sql = "select * from student s1"
                 + " left join (select sid as id1, sid as id2, grade from score) s2 on s1.id = s2.id1 and s1.id = 1"
                 + " join (select sid as id1, sid as id2, grade from score) s3 on s1.id = s3.id1 where s1.id = 2";
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java
new file mode 100644
index 00000000000..b1aa25df1b1
--- /dev/null
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java
@@ -0,0 +1,51 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.InPredicate;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.types.BigIntType;
+import org.apache.doris.nereids.types.SmallIntType;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import org.junit.jupiter.api.Test;
+
+import java.util.Set;
+
+class PredicatePropagationTest {
+    private final SlotReference a = new SlotReference("a", SmallIntType.INSTANCE);
+    private final SlotReference b = new SlotReference("b", BigIntType.INSTANCE);
+
+    @Test
+    void equal() {
+        Set<Expression> exprs = ImmutableSet.of(new EqualTo(a, b), new EqualTo(a, Literal.of(1)));
+        Set<Expression> inferExprs = PredicatePropagation.infer(exprs);
+        System.out.println(inferExprs);
+    }
+
+    @Test
+    void in() {
+        Set<Expression> exprs = ImmutableSet.of(new EqualTo(a, b), new InPredicate(a, ImmutableList.of(Literal.of(1))));
+        Set<Expression> inferExprs = PredicatePropagation.infer(exprs);
+        System.out.println(inferExprs);
+    }
+}
diff --git a/regression-test/data/nereids_p0/hint/fix_leading.out b/regression-test/data/nereids_p0/hint/fix_leading.out
index a3ca4f54110..58122945bb6 100644
--- a/regression-test/data/nereids_p0/hint/fix_leading.out
+++ b/regression-test/data/nereids_p0/hint/fix_leading.out
@@ -9,7 +9,7 @@ PhysicalResultSink
 ----------PhysicalDistribute[DistributionSpecHash]
 ------------PhysicalOlapScan[t2]
 --------PhysicalDistribute[DistributionSpecHash]
-----------NestedLoopJoin[CROSS_JOIN]
+----------NestedLoopJoin[CROSS_JOIN](t4.c4 = t3.c3)(t3.c3 = t4.c4)
 ------------PhysicalOlapScan[t3]
 ------------PhysicalDistribute[DistributionSpecReplicated]
 --------------PhysicalOlapScan[t4]
diff --git a/regression-test/data/nereids_p0/hint/test_leading.out b/regression-test/data/nereids_p0/hint/test_leading.out
index d1bd8f8bd28..fe3831a9fc4 100644
--- a/regression-test/data/nereids_p0/hint/test_leading.out
+++ b/regression-test/data/nereids_p0/hint/test_leading.out
@@ -2609,7 +2609,7 @@ PhysicalResultSink
 ------------PhysicalProject
 --------------PhysicalOlapScan[t2]
 ------------PhysicalDistribute[DistributionSpecReplicated]
---------------NestedLoopJoin[CROSS_JOIN]
+--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3)
 ----------------PhysicalProject
 ------------------PhysicalOlapScan[t1]
 ----------------PhysicalDistribute[DistributionSpecReplicated]
@@ -2631,7 +2631,7 @@ PhysicalResultSink
 ------------PhysicalProject
 --------------PhysicalOlapScan[t2]
 ------------PhysicalDistribute[DistributionSpecReplicated]
---------------NestedLoopJoin[CROSS_JOIN]
+--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3)
 ----------------PhysicalProject
 ------------------PhysicalOlapScan[t3]
 ----------------PhysicalDistribute[DistributionSpecReplicated]
@@ -2745,7 +2745,7 @@ PhysicalResultSink
 ------------PhysicalProject
 --------------PhysicalOlapScan[t2]
 ------------PhysicalDistribute[DistributionSpecReplicated]
---------------NestedLoopJoin[CROSS_JOIN]
+--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3)
 ----------------PhysicalProject
 ------------------PhysicalOlapScan[t1]
 ----------------PhysicalDistribute[DistributionSpecReplicated]
@@ -2767,7 +2767,7 @@ PhysicalResultSink
 ------------PhysicalProject
 --------------PhysicalOlapScan[t2]
 ------------PhysicalDistribute[DistributionSpecReplicated]
---------------NestedLoopJoin[CROSS_JOIN]
+--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3)
 ----------------PhysicalProject
 ------------------PhysicalOlapScan[t3]
 ----------------PhysicalDistribute[DistributionSpecReplicated]
@@ -2881,7 +2881,7 @@ PhysicalResultSink
 ------------PhysicalProject
 --------------PhysicalOlapScan[t2]
 ------------PhysicalDistribute[DistributionSpecHash]
---------------NestedLoopJoin[CROSS_JOIN]
+--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3)
 ----------------PhysicalProject
 ------------------PhysicalOlapScan[t1]
 ----------------PhysicalDistribute[DistributionSpecReplicated]
@@ -2903,7 +2903,7 @@ PhysicalResultSink
 ------------PhysicalProject
 --------------PhysicalOlapScan[t2]
 ------------PhysicalDistribute[DistributionSpecHash]
---------------NestedLoopJoin[CROSS_JOIN]
+--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3)
 ----------------PhysicalProject
 ------------------PhysicalOlapScan[t3]
 ----------------PhysicalDistribute[DistributionSpecReplicated]
diff --git a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
index c5942680ea7..55645ed8ea0 100644
--- a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
+++ b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
@@ -41,7 +41,7 @@ suite("test_infer_predicate") {
 
     explain {
         sql "select * from infer_tb1 inner join infer_tb2 where cast(infer_tb2.k4 as int) = infer_tb1.k2  and infer_tb2.k4 = 1;"
-        contains "PREDICATES: k2"
+        contains "PREDICATES: CAST(k2"
     }
 
     explain {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to