This is an automated email from the ASF dual-hosted git repository.

kxiao pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-2.0 by this push:
     new 203f072cf2a [fix](Nereids): NullSafeEqual should be in HashJoinCondition #27127 (#27232)
203f072cf2a is described below

commit 203f072cf2a2115135266368dc6fecb824e871e3
Author: jakevin <[email protected]>
AuthorDate: Sat Nov 18 23:57:01 2023 +0800

     [fix](Nereids): NullSafeEqual should be in HashJoinCondition #27127 (#27232)
---
 .../glue/translator/PhysicalPlanTranslator.java    |  4 +-
 .../processor/post/RuntimeFilterGenerator.java     | 16 +++---
 .../nereids/rules/rewrite/EliminateOuterJoin.java  | 60 ++++++++++++++++++++++
 .../PushdownExpressionsInHashCondition.java        |  7 ++-
 .../mv/AbstractSelectMaterializedIndexRule.java    |  5 +-
 .../doris/nereids/stats/FilterEstimation.java      |  6 +--
 .../apache/doris/nereids/stats/JoinEstimation.java | 28 +++++-----
 .../nereids/trees/expressions/EqualPredicate.java  | 36 +++++++++++++
 .../doris/nereids/trees/expressions/EqualTo.java   |  4 +-
 .../nereids/trees/expressions/NullSafeEqual.java   | 11 +---
 .../trees/plans/physical/AbstractPhysicalJoin.java |  7 +++
 .../org/apache/doris/nereids/util/JoinUtils.java   | 49 ++++++------------
 12 files changed, 153 insertions(+), 80 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
index 4a413158b3a..a3ac99ba359 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
@@ -69,7 +69,7 @@ import org.apache.doris.nereids.stats.StatsErrorEstimator;
 import org.apache.doris.nereids.trees.UnaryNode;
 import org.apache.doris.nereids.trees.expressions.AggregateExpression;
 import org.apache.doris.nereids.trees.expressions.CTEId;
-import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.EqualPredicate;
 import org.apache.doris.nereids.trees.expressions.ExprId;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
@@ -1114,7 +1114,7 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
         JoinType joinType = hashJoin.getJoinType();
 
         List<Expr> execEqConjuncts = hashJoin.getHashJoinConjuncts().stream()
-                .map(EqualTo.class::cast)
+                .map(EqualPredicate.class::cast)
                 .map(e -> JoinUtils.swapEqualToForChildrenOrder(e, 
hashJoin.left().getOutputSet()))
                 .map(e -> ExpressionTranslator.translate(e, context))
                 .collect(Collectors.toList());
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java
index 0243326b106..de13b7adb70 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java
@@ -281,11 +281,10 @@ public class RuntimeFilterGenerator extends 
PlanPostProcessor {
         List<TRuntimeFilterType> legalTypes = 
Arrays.stream(TRuntimeFilterType.values())
                 .filter(type -> (type.getValue() & 
ctx.getSessionVariable().getRuntimeFilterType()) > 0)
                 .collect(Collectors.toList());
-        // TODO: some complex situation cannot be handled now, see 
testPushDownThroughJoin.
-        //   we will support it in later version.
-        for (int i = 0; i < join.getHashJoinConjuncts().size(); i++) {
+        List<EqualTo> hashJoinConjuncts = join.getEqualToConjuncts();
+        for (int i = 0; i < hashJoinConjuncts.size(); i++) {
             EqualTo equalTo = ((EqualTo) JoinUtils.swapEqualToForChildrenOrder(
-                    (EqualTo) join.getHashJoinConjuncts().get(i), 
join.left().getOutputSet()));
+                    hashJoinConjuncts.get(i), join.left().getOutputSet()));
             for (TRuntimeFilterType type : legalTypes) {
                 //bitmap rf is generated by nested loop join.
                 if (type == TRuntimeFilterType.BITMAP) {
@@ -525,7 +524,7 @@ public class RuntimeFilterGenerator extends 
PlanPostProcessor {
                         || !(join.getHashJoinConjuncts().get(0) instanceof 
EqualTo)) {
                     break;
                 } else {
-                    EqualTo equalTo = (EqualTo) 
join.getHashJoinConjuncts().get(0);
+                    EqualTo equalTo = (EqualTo) 
join.getEqualToConjuncts().get(0);
                     equalTos.add(equalTo);
                     equalCondToJoinMap.put(equalTo, join);
                 }
@@ -561,12 +560,11 @@ public class RuntimeFilterGenerator extends 
PlanPostProcessor {
                         // check further whether the join upper side can bring 
equal set, which
                         // indicating actually the same runtime filter build 
side
                         // see above case 2 for reference
-                        List<Expression> conditions = 
curJoin.getHashJoinConjuncts();
                         boolean inSameEqualSet = false;
-                        for (Expression e : conditions) {
+                        for (EqualTo e : curJoin.getEqualToConjuncts()) {
                             if (e instanceof EqualTo) {
-                                SlotReference oneSide = (SlotReference) 
((EqualTo) e).left();
-                                SlotReference anotherSide = (SlotReference) 
((EqualTo) e).right();
+                                SlotReference oneSide = (SlotReference) 
e.left();
+                                SlotReference anotherSide = (SlotReference) 
e.right();
                                 if (anotherSideSlotSet.contains(oneSide) && 
anotherSideSlotSet.contains(anotherSide)) {
                                     inSameEqualSet = true;
                                     break;
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateOuterJoin.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateOuterJoin.java
index 83cc37ed0b3..c2dcafbee43 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateOuterJoin.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateOuterJoin.java
@@ -19,17 +19,23 @@ package org.apache.doris.nereids.rules.rewrite;
 
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.EqualPredicate;
 import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.IsNull;
+import org.apache.doris.nereids.trees.expressions.Not;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.plans.JoinType;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.util.JoinUtils;
 import org.apache.doris.nereids.util.TypeUtils;
 import org.apache.doris.nereids.util.Utils;
 
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.ImmutableSet.Builder;
+import com.google.common.collect.Sets;
 
+import java.util.Collection;
 import java.util.HashSet;
 import java.util.Optional;
 import java.util.Set;
@@ -63,6 +69,45 @@ public class EliminateOuterJoin extends 
OneRewriteRuleFactory {
             }
 
             JoinType newJoinType = tryEliminateOuterJoin(join.getJoinType(), 
canFilterLeftNull, canFilterRightNull);
+            Set<Expression> conjuncts = Sets.newHashSet();
+            conjuncts.addAll(filter.getConjuncts());
+            boolean conjunctsChanged = false;
+            if (!notNullSlots.isEmpty()) {
+                for (Slot slot : notNullSlots) {
+                    Not isNotNull = new Not(new IsNull(slot));
+                    isNotNull.isGeneratedIsNotNull = true;
+                    conjunctsChanged |= conjuncts.add(isNotNull);
+                }
+            }
+            if (newJoinType.isInnerJoin()) {
+                /*
+                 * for example: (A left join B on A.a=B.b) join C on B.x=C.x
+                 * inner join condition B.x=C.x implies 'B.x is not null',
+                 * by which the left outer join could be eliminated. Finally, 
the join transformed to
+                 * (A join B on A.a=B.b) join C on B.x=C.x.
+                 * This elimination can be processed recursively.
+                 *
+                 * TODO: is_not_null can also be inferred from A < B and so on
+                 */
+                conjunctsChanged |= join.getHashJoinConjuncts().stream()
+                        .map(EqualPredicate.class::cast)
+                        .map(equalTo -> 
JoinUtils.swapEqualToForChildrenOrder(equalTo, join.left().getOutputSet()))
+                        .anyMatch(equalTo -> 
createIsNotNullIfNecessary(equalTo, conjuncts));
+
+                JoinUtils.JoinSlotCoverageChecker checker = new 
JoinUtils.JoinSlotCoverageChecker(
+                        join.left().getOutput(),
+                        join.right().getOutput());
+                conjunctsChanged |= join.getOtherJoinConjuncts().stream()
+                        .filter(EqualPredicate.class::isInstance)
+                        .filter(equalTo -> 
checker.isHashJoinCondition((EqualPredicate) equalTo))
+                        .map(equalTo -> 
JoinUtils.swapEqualToForChildrenOrder((EqualPredicate) equalTo,
+                                join.left().getOutputSet()))
+                        .anyMatch(equalTo -> 
createIsNotNullIfNecessary(equalTo, conjuncts));
+            }
+            if (conjunctsChanged) {
+                return 
filter.withConjuncts(conjuncts.stream().collect(ImmutableSet.toImmutableSet()))
+                        .withChildren(join.withJoinType(newJoinType));
+            }
             return filter.withChildren(join.withJoinType(newJoinType));
         }).toRule(RuleType.ELIMINATE_OUTER_JOIN);
     }
@@ -85,4 +130,19 @@ public class EliminateOuterJoin extends 
OneRewriteRuleFactory {
         }
         return joinType;
     }
+
+    private boolean createIsNotNullIfNecessary(EqualPredicate swapedEqualTo, Collection<Expression> container) {
+        // `a = b` (nullable EqualTo) implies both sides are not null. NullSafeEqual `a <=> b` is never nullable and
+        // is true for NULL <=> NULL, so inferring IS NOT NULL from it would filter out rows the join must keep;
+        // an EqualTo whose two sides are both non-nullable adds nothing either way.
+        boolean containerChanged = false;
+        for (Expression side : swapedEqualTo.children()) {
+            if (swapedEqualTo.nullable() && side.nullable()) {
+                Not not = new Not(new IsNull(side));
+                not.isGeneratedIsNotNull = true;
+                containerChanged |= container.add(not);
+            }
+        }
+        return containerChanged;
+    }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownExpressionsInHashCondition.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownExpressionsInHashCondition.java
index 05da591526c..df7acb4553c 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownExpressionsInHashCondition.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownExpressionsInHashCondition.java
@@ -20,7 +20,7 @@ package org.apache.doris.nereids.rules.rewrite;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
 import org.apache.doris.nereids.trees.expressions.Alias;
-import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.EqualPredicate;
 import org.apache.doris.nereids.trees.expressions.ExprId;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
@@ -77,11 +77,10 @@ public class PushdownExpressionsInHashCondition extends 
OneRewriteRuleFactory {
         Set<NamedExpression> rightProjectExprs = Sets.newHashSet();
         Map<Expression, NamedExpression> exprReplaceMap = Maps.newHashMap();
         join.getHashJoinConjuncts().forEach(conjunct -> {
-            Preconditions.checkArgument(conjunct instanceof EqualTo);
+            Preconditions.checkArgument(conjunct instanceof EqualPredicate);
             // sometimes: t1 join t2 on t2.a + 1 = t1.a + 2, so check the 
situation, but actually it
             // doesn't swap the two sides.
-            conjunct = JoinUtils.swapEqualToForChildrenOrder(
-                    (EqualTo) conjunct, join.left().getOutputSet());
+            conjunct = JoinUtils.swapEqualToForChildrenOrder((EqualPredicate) 
conjunct, join.left().getOutputSet());
             generateReplaceMapAndProjectExprs(conjunct.child(0), 
exprReplaceMap, leftProjectExprs);
             generateReplaceMapAndProjectExprs(conjunct.child(1), 
exprReplaceMap, rightProjectExprs);
         });
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java
index 012dec4c91c..c1550cb5bd5 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java
@@ -27,13 +27,12 @@ import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.CaseWhen;
 import org.apache.doris.nereids.trees.expressions.Cast;
 import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
-import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.EqualPredicate;
 import org.apache.doris.nereids.trees.expressions.ExprId;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.InPredicate;
 import org.apache.doris.nereids.trees.expressions.IsNull;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
-import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.WhenClause;
@@ -306,7 +305,7 @@ public abstract class AbstractSelectMaterializedIndexRule {
 
         @Override
         public PrefixIndexCheckResult 
visitComparisonPredicate(ComparisonPredicate cp, Map<ExprId, String> context) {
-            if (cp instanceof EqualTo || cp instanceof NullSafeEqual) {
+            if (cp instanceof EqualPredicate) {
                 return check(cp, context, PrefixIndexCheckResult::createEqual);
             } else {
                 return check(cp, context, 
PrefixIndexCheckResult::createNonEqual);
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java
index f06c9d1cc4f..a412ff375fd 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java
@@ -23,6 +23,7 @@ import org.apache.doris.nereids.trees.TreeNode;
 import org.apache.doris.nereids.trees.expressions.And;
 import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
 import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
+import org.apache.doris.nereids.trees.expressions.EqualPredicate;
 import org.apache.doris.nereids.trees.expressions.EqualTo;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.GreaterThan;
@@ -33,7 +34,6 @@ import org.apache.doris.nereids.trees.expressions.LessThan;
 import org.apache.doris.nereids.trees.expressions.LessThanEqual;
 import org.apache.doris.nereids.trees.expressions.Like;
 import org.apache.doris.nereids.trees.expressions.Not;
-import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
 import org.apache.doris.nereids.trees.expressions.Or;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
@@ -210,7 +210,7 @@ public class FilterEstimation extends 
ExpressionVisitor<Statistics, EstimationCo
             return context.statistics.withSel(DEFAULT_INEQUALITY_COEFFICIENT);
         }
 
-        if (cp instanceof EqualTo || cp instanceof NullSafeEqual) {
+        if (cp instanceof EqualPredicate) {
             return estimateEqualTo(cp, statsForLeft, statsForRight, context);
         } else {
             if (cp instanceof LessThan || cp instanceof LessThanEqual) {
@@ -255,7 +255,7 @@ public class FilterEstimation extends 
ExpressionVisitor<Statistics, EstimationCo
             ColumnStatistic statsForLeft, ColumnStatistic statsForRight) {
         Expression left = cp.left();
         Expression right = cp.right();
-        if (cp instanceof EqualTo || cp instanceof NullSafeEqual) {
+        if (cp instanceof EqualPredicate) {
             return estimateColumnEqualToColumn(left, statsForLeft, right, 
statsForRight, context);
         }
         if (cp instanceof GreaterThan || cp instanceof GreaterThanEqual) {
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java
index 800886c177f..f9d25cab171 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java
@@ -19,7 +19,7 @@ package org.apache.doris.nereids.stats;
 
 import org.apache.doris.nereids.exceptions.AnalysisException;
 import org.apache.doris.nereids.trees.expressions.Cast;
-import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.EqualPredicate;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.plans.JoinType;
@@ -45,14 +45,14 @@ import java.util.stream.Collectors;
 public class JoinEstimation {
     private static double DEFAULT_ANTI_JOIN_SELECTIVITY_COEFFICIENT = 0.3;
 
-    private static EqualTo normalizeHashJoinCondition(EqualTo equalTo, 
Statistics leftStats, Statistics rightStats) {
-        boolean changeOrder = equalTo.left().getInputSlots().stream().anyMatch(
-                slot -> rightStats.findColumnStatistics(slot) != null
-        );
+    private static EqualPredicate normalizeHashJoinCondition(EqualPredicate 
equal, Statistics leftStats,
+            Statistics rightStats) {
+        boolean changeOrder = equal.left().getInputSlots().stream()
+                .anyMatch(slot -> rightStats.findColumnStatistics(slot) != 
null);
         if (changeOrder) {
-            return new EqualTo(equalTo.right(), equalTo.left());
+            return equal.commute();
         } else {
-            return equalTo;
+            return equal;
         }
     }
 
@@ -81,18 +81,18 @@ public class JoinEstimation {
          * In order to avoid error propagation, for unTrustEquations, we only 
use the biggest selectivity.
          */
         List<Double> unTrustEqualRatio = Lists.newArrayList();
-        List<EqualTo> unTrustableCondition = Lists.newArrayList();
+        List<EqualPredicate> unTrustableCondition = Lists.newArrayList();
         boolean leftBigger = leftStats.getRowCount() > 
rightStats.getRowCount();
         double rightStatsRowCount = 
StatsMathUtil.nonZeroDivisor(rightStats.getRowCount());
         double leftStatsRowCount = 
StatsMathUtil.nonZeroDivisor(leftStats.getRowCount());
-        List<EqualTo> trustableConditions = 
join.getHashJoinConjuncts().stream()
-                .map(expression -> (EqualTo) expression)
+        List<EqualPredicate> trustableConditions = 
join.getHashJoinConjuncts().stream()
+                .map(expression -> (EqualPredicate) expression)
                 .filter(
                         expression -> {
                             // since ndv is not accurate, if ndv/rowcount < 
almostUniqueThreshold,
                             // this column is regarded as unique.
                             double almostUniqueThreshold = 0.9;
-                            EqualTo equal = 
normalizeHashJoinCondition(expression, leftStats, rightStats);
+                            EqualPredicate equal = 
normalizeHashJoinCondition(expression, leftStats, rightStats);
                             ColumnStatistic eqLeftColStats = 
ExpressionEstimation.estimate(equal.left(), leftStats);
                             ColumnStatistic eqRightColStats = 
ExpressionEstimation.estimate(equal.right(), rightStats);
                             boolean trustable = eqRightColStats.ndv / 
rightStatsRowCount > almostUniqueThreshold
@@ -189,7 +189,7 @@ public class JoinEstimation {
     }
 
     private static double estimateSemiOrAntiRowCountBySlotsEqual(Statistics 
leftStats,
-            Statistics rightStats, Join join, EqualTo equalTo) {
+            Statistics rightStats, Join join, EqualPredicate equalTo) {
         Expression eqLeft = equalTo.left();
         Expression eqRight = equalTo.right();
         ColumnStatistic probColStats = leftStats.findColumnStatistics(eqLeft);
@@ -246,7 +246,7 @@ public class JoinEstimation {
         double rowCount = Double.POSITIVE_INFINITY;
         for (Expression conjunct : join.getHashJoinConjuncts()) {
             double eqRowCount = 
estimateSemiOrAntiRowCountBySlotsEqual(leftStats, rightStats,
-                    join, (EqualTo) conjunct);
+                    join, (EqualPredicate) conjunct);
             if (rowCount > eqRowCount) {
                 rowCount = eqRowCount;
             }
@@ -321,7 +321,7 @@ public class JoinEstimation {
     private static Statistics 
updateJoinResultStatsByHashJoinCondition(Statistics innerStats, Join join) {
         Map<Expression, ColumnStatistic> updatedCols = new HashMap<>();
         for (Expression expr : join.getHashJoinConjuncts()) {
-            EqualTo equalTo = (EqualTo) expr;
+            EqualPredicate equalTo = (EqualPredicate) expr;
             ColumnStatistic leftColStats = 
ExpressionEstimation.estimate(equalTo.left(), innerStats);
             ColumnStatistic rightColStats = 
ExpressionEstimation.estimate(equalTo.right(), innerStats);
             double minNdv = Math.min(leftColStats.ndv, rightColStats.ndv);
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualPredicate.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualPredicate.java
new file mode 100644
index 00000000000..3f61bd3cf62
--- /dev/null
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualPredicate.java
@@ -0,0 +1,36 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions;
+
+import java.util.List;
+
+/**
+ * Base class of the equality predicates {@code a = b} (EqualTo) and {@code a <=> b} (NullSafeEqual).
+ * Both forms can be used as hash join conditions.
+ */
+public abstract class EqualPredicate extends ComparisonPredicate {
+
+    protected EqualPredicate(List<Expression> children, String symbol) {
+        super(children, symbol);
+    }
+
+    /** Swaps the two children; each subclass must return a new instance of its own type, never null. */
+    @Override
+    public abstract EqualPredicate commute();
+}
+
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java
index 0fa23a57e0a..1e72a006057 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java
@@ -29,7 +29,7 @@ import java.util.List;
 /**
  * Equal to expression: a = b.
  */
-public class EqualTo extends ComparisonPredicate implements PropagateNullable {
+public class EqualTo extends EqualPredicate implements PropagateNullable {
 
     public EqualTo(Expression left, Expression right) {
         super(ImmutableList.of(left, right), "=");
@@ -60,7 +60,7 @@ public class EqualTo extends ComparisonPredicate implements 
PropagateNullable {
     }
 
     @Override
-    public ComparisonPredicate commute() {
+    public EqualTo commute() {
         return new EqualTo(right(), left());
     }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NullSafeEqual.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NullSafeEqual.java
index c2b63aebbd7..48d05364fa3 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NullSafeEqual.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NullSafeEqual.java
@@ -29,13 +29,7 @@ import java.util.List;
  * Null safe equal expression: a <=> b.
  * Unlike normal equal to expression, null <=> null is true.
  */
-public class NullSafeEqual extends ComparisonPredicate implements 
AlwaysNotNullable {
-    /**
-     * Constructor of Null Safe Equal ComparisonPredicate.
-     *
-     * @param left  left child of Null Safe Equal
-     * @param right right child of Null Safe Equal
-     */
+public class NullSafeEqual extends EqualPredicate implements AlwaysNotNullable 
{
     public NullSafeEqual(Expression left, Expression right) {
         super(ImmutableList.of(left, right), "<=>");
     }
@@ -61,8 +55,7 @@ public class NullSafeEqual extends ComparisonPredicate 
implements AlwaysNotNulla
     }
 
     @Override
-    public ComparisonPredicate commute() {
+    public NullSafeEqual commute() {
         return new NullSafeEqual(right(), left());
     }
-
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/AbstractPhysicalJoin.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/AbstractPhysicalJoin.java
index f67123522c3..a39634917aa 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/AbstractPhysicalJoin.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/AbstractPhysicalJoin.java
@@ -20,6 +20,7 @@ package org.apache.doris.nereids.trees.plans.physical;
 import org.apache.doris.nereids.memo.GroupExpression;
 import org.apache.doris.nereids.properties.LogicalProperties;
 import org.apache.doris.nereids.properties.PhysicalProperties;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
 import org.apache.doris.nereids.trees.expressions.Slot;
@@ -41,6 +42,7 @@ import java.util.Collection;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
+import java.util.stream.Collectors;
 
 /**
  * Abstract class for all physical join node.
@@ -109,6 +111,11 @@ public abstract class AbstractPhysicalJoin<
         return hashJoinConjuncts;
     }
 
+    public List<EqualTo> getEqualToConjuncts() {
+        return 
hashJoinConjuncts.stream().filter(EqualTo.class::isInstance).map(EqualTo.class::cast)
+                .collect(Collectors.toList());
+    }
+
     public boolean isShouldTranslateOutput() {
         return shouldTranslateOutput;
     }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java
index eda7d2e6ad1..25f84c096c8 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java
@@ -24,7 +24,7 @@ import org.apache.doris.nereids.properties.DistributionSpec;
 import org.apache.doris.nereids.properties.DistributionSpecHash;
 import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType;
 import org.apache.doris.nereids.properties.DistributionSpecReplicated;
-import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.EqualPredicate;
 import org.apache.doris.nereids.trees.expressions.ExprId;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.Not;
@@ -61,31 +61,18 @@ public class JoinUtils {
         return !(join.getJoinType().isRightJoin() || 
join.getJoinType().isFullOuterJoin());
     }
 
-    private static final class JoinSlotCoverageChecker {
+    /**
+     * JoinSlotCoverageChecker
+     */
+    public static final class JoinSlotCoverageChecker {
         Set<ExprId> leftExprIds;
         Set<ExprId> rightExprIds;
 
-        JoinSlotCoverageChecker(List<Slot> left, List<Slot> right) {
+        public JoinSlotCoverageChecker(List<Slot> left, List<Slot> right) {
             leftExprIds = 
left.stream().map(Slot::getExprId).collect(Collectors.toSet());
             rightExprIds = 
right.stream().map(Slot::getExprId).collect(Collectors.toSet());
         }
 
-        JoinSlotCoverageChecker(Set<ExprId> left, Set<ExprId> right) {
-            leftExprIds = left;
-            rightExprIds = right;
-        }
-
-        /**
-         * PushDownExpressionInHashConjuncts ensure the "slots" is only one 
slot.
-         */
-        boolean isCoveredByLeftSlots(ExprId slot) {
-            return leftExprIds.contains(slot);
-        }
-
-        boolean isCoveredByRightSlots(ExprId slot) {
-            return rightExprIds.contains(slot);
-        }
-
         /**
          * consider following cases:
          * 1# A=1 => not for hash table
@@ -94,25 +81,20 @@ public class JoinUtils {
          * 4# t1.a=t2.a or t1.b=t2.b not for hash table
          * 5# t1.a > 1 not for hash table
          *
-         * @param equalTo a conjunct in on clause condition
+         * @param equal a conjunct in on clause condition
          * @return true if the equal can be used as hash join condition
          */
-        boolean isHashJoinCondition(EqualTo equalTo) {
-            Set<Slot> equalLeft = 
equalTo.left().collect(Slot.class::isInstance);
-            if (equalLeft.isEmpty()) {
+        public boolean isHashJoinCondition(EqualPredicate equal) {
+            Set<ExprId> equalLeftExprIds = equal.left().getInputSlotExprIds();
+            if (equalLeftExprIds.isEmpty()) {
                 return false;
             }
 
-            Set<Slot> equalRight = 
equalTo.right().collect(Slot.class::isInstance);
-            if (equalRight.isEmpty()) {
+            Set<ExprId> equalRightExprIds = 
equal.right().getInputSlotExprIds();
+            if (equalRightExprIds.isEmpty()) {
                 return false;
             }
 
-            List<ExprId> equalLeftExprIds = equalLeft.stream()
-                    .map(Slot::getExprId).collect(Collectors.toList());
-
-            List<ExprId> equalRightExprIds = equalRight.stream()
-                    .map(Slot::getExprId).collect(Collectors.toList());
             return leftExprIds.containsAll(equalLeftExprIds) && 
rightExprIds.containsAll(equalRightExprIds)
                     || leftExprIds.containsAll(equalRightExprIds) && 
rightExprIds.containsAll(equalLeftExprIds);
         }
@@ -129,9 +111,8 @@ public class JoinUtils {
     public static Pair<List<Expression>, List<Expression>> 
extractExpressionForHashTable(List<Slot> leftSlots,
             List<Slot> rightSlots, List<Expression> onConditions) {
         JoinSlotCoverageChecker checker = new 
JoinSlotCoverageChecker(leftSlots, rightSlots);
-        Map<Boolean, List<Expression>> mapper = onConditions.stream()
-                .collect(Collectors.groupingBy(
-                        expr -> (expr instanceof EqualTo) && 
checker.isHashJoinCondition((EqualTo) expr)));
+        Map<Boolean, List<Expression>> mapper = 
onConditions.stream().collect(Collectors.groupingBy(
+                expr -> (expr instanceof EqualPredicate) && 
checker.isHashJoinCondition((EqualPredicate) expr)));
         return Pair.of(
                 mapper.getOrDefault(true, ImmutableList.of()),
                 mapper.getOrDefault(false, ImmutableList.of())
@@ -187,7 +168,7 @@ public class JoinUtils {
      * The left child of origin predicate is t2.id and the right child of 
origin predicate is t1.id.
      * In this situation, the children of predicate need to be swap => 
t1.id=t2.id.
      */
-    public static Expression swapEqualToForChildrenOrder(EqualTo equalTo, 
Set<Slot> leftOutput) {
+    public static EqualPredicate swapEqualToForChildrenOrder(EqualPredicate 
equalTo, Set<Slot> leftOutput) {
         if (leftOutput.containsAll(equalTo.left().getInputSlots())) {
             return equalTo;
         } else {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to