This is an automated email from the ASF dual-hosted git repository.

jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new dea40e7095c [fix](Nereids): NullSafeEqual should be in HashJoinCondition (#27127)
dea40e7095c is described below

commit dea40e7095c28f72b313c9071dfc0f4d6d4ecb5c
Author: jakevin <[email protected]>
AuthorDate: Tue Nov 21 19:08:14 2023 +0800

    [fix](Nereids): NullSafeEqual should be in HashJoinCondition (#27127)
    
    Originally, we only put `EqualTo` in `HashJoinCondition`; we also need to allow `NullSafeEqual` there.
---
 .../glue/translator/PhysicalPlanTranslator.java    |  4 +-
 .../processor/post/RuntimeFilterGenerator.java     | 14 ++---
 .../processor/post/RuntimeFilterPruner.java        |  2 +-
 .../nereids/rules/rewrite/EliminateOuterJoin.java  | 23 ++++-----
 .../PushdownExpressionsInHashCondition.java        |  7 ++-
 .../mv/AbstractSelectMaterializedIndexRule.java    |  5 +-
 .../doris/nereids/stats/FilterEstimation.java      |  6 +--
 .../apache/doris/nereids/stats/JoinEstimation.java | 28 +++++-----
 .../nereids/trees/expressions/EqualPredicate.java  | 36 +++++++++++++
 .../doris/nereids/trees/expressions/EqualTo.java   |  4 +-
 .../nereids/trees/expressions/NullSafeEqual.java   | 11 +---
 .../doris/nereids/trees/plans/algebra/Join.java    |  7 +++
 .../trees/plans/physical/PhysicalHashJoin.java     |  4 +-
 .../org/apache/doris/nereids/util/JoinUtils.java   | 42 ++++-----------
 .../data/nereids_p0/join/test_join_15.out          | 60 +++++++++++++++++++---
 .../suites/nereids_p0/join/test_join_15.groovy     |  6 ++-
 16 files changed, 160 insertions(+), 99 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
index 60623320128..d4a8849aee8 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
@@ -69,7 +69,7 @@ import org.apache.doris.nereids.stats.StatsErrorEstimator;
 import org.apache.doris.nereids.trees.UnaryNode;
 import org.apache.doris.nereids.trees.expressions.AggregateExpression;
 import org.apache.doris.nereids.trees.expressions.CTEId;
-import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.EqualPredicate;
 import org.apache.doris.nereids.trees.expressions.ExprId;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
@@ -1138,7 +1138,7 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
         JoinType joinType = hashJoin.getJoinType();
 
         List<Expr> execEqConjuncts = hashJoin.getHashJoinConjuncts().stream()
-                .map(EqualTo.class::cast)
+                .map(EqualPredicate.class::cast)
                 .map(e -> JoinUtils.swapEqualToForChildrenOrder(e, 
hashJoin.left().getOutputSet()))
                 .map(e -> ExpressionTranslator.translate(e, context))
                 .collect(Collectors.toList());
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java
index 579fe7485ab..cff906df208 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java
@@ -363,9 +363,10 @@ public class RuntimeFilterGenerator extends 
PlanPostProcessor {
         List<TRuntimeFilterType> legalTypes = 
Arrays.stream(TRuntimeFilterType.values())
                 .filter(type -> (type.getValue() & 
ctx.getSessionVariable().getRuntimeFilterType()) > 0)
                 .collect(Collectors.toList());
-        for (int i = 0; i < join.getHashJoinConjuncts().size(); i++) {
+        List<EqualTo> hashJoinConjuncts = join.getEqualToConjuncts();
+        for (int i = 0; i < hashJoinConjuncts.size(); i++) {
             EqualTo equalTo = ((EqualTo) JoinUtils.swapEqualToForChildrenOrder(
-                    (EqualTo) join.getHashJoinConjuncts().get(i), 
join.left().getOutputSet()));
+                    hashJoinConjuncts.get(i), join.left().getOutputSet()));
             for (TRuntimeFilterType type : legalTypes) {
                 //bitmap rf is generated by nested loop join.
                 if (type == TRuntimeFilterType.BITMAP) {
@@ -487,7 +488,7 @@ public class RuntimeFilterGenerator extends 
PlanPostProcessor {
                         || !(join.getHashJoinConjuncts().get(0) instanceof 
EqualTo)) {
                     break;
                 } else {
-                    EqualTo equalTo = (EqualTo) 
join.getHashJoinConjuncts().get(0);
+                    EqualTo equalTo = join.getEqualToConjuncts().get(0);
                     equalTos.add(equalTo);
                     equalCondToJoinMap.put(equalTo, join);
                 }
@@ -523,12 +524,11 @@ public class RuntimeFilterGenerator extends 
PlanPostProcessor {
                         // check further whether the join upper side can bring 
equal set, which
                         // indicating actually the same runtime filter build 
side
                         // see above case 2 for reference
-                        List<Expression> conditions = 
curJoin.getHashJoinConjuncts();
                         boolean inSameEqualSet = false;
-                        for (Expression e : conditions) {
+                        for (EqualTo e : curJoin.getEqualToConjuncts()) {
                             if (e instanceof EqualTo) {
-                                SlotReference oneSide = (SlotReference) 
((EqualTo) e).left();
-                                SlotReference anotherSide = (SlotReference) 
((EqualTo) e).right();
+                                SlotReference oneSide = (SlotReference) 
e.left();
+                                SlotReference anotherSide = (SlotReference) 
e.right();
                                 if (anotherSideSlotSet.contains(oneSide) && 
anotherSideSlotSet.contains(anotherSide)) {
                                     inSameEqualSet = true;
                                     break;
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterPruner.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterPruner.java
index e32d12edac3..b39bb8ec180 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterPruner.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterPruner.java
@@ -85,7 +85,7 @@ public class RuntimeFilterPruner extends PlanPostProcessor {
             List<ExprId> exprIds = ctx.getTargetExprIdByFilterJoin(join);
             if (exprIds != null && !exprIds.isEmpty()) {
                 boolean isEffective = false;
-                for (Expression expr : join.getHashJoinConjuncts()) {
+                for (Expression expr : join.getEqualToConjuncts()) {
                     if (isEffectiveRuntimeFilter((EqualTo) expr, join)) {
                         isEffective = true;
                     }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateOuterJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateOuterJoin.java
index 440e5d73ae3..1afd6a175f3 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateOuterJoin.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateOuterJoin.java
@@ -19,6 +19,7 @@ package org.apache.doris.nereids.rules.rewrite;
 
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.EqualPredicate;
 import org.apache.doris.nereids.trees.expressions.EqualTo;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.IsNull;
@@ -89,24 +90,20 @@ public class EliminateOuterJoin extends 
OneRewriteRuleFactory {
                  *
                  * TODO: is_not_null can also be inferred from A < B and so on
                  */
-                conjunctsChanged |= join.getHashJoinConjuncts().stream()
+                conjunctsChanged |= join.getEqualToConjuncts().stream()
                         .map(EqualTo.class::cast)
-                        .map(equalTo ->
-                                (EqualTo) 
JoinUtils.swapEqualToForChildrenOrder(equalTo, join.left().getOutputSet()))
-                        .map(equalTo -> createIsNotNullIfNecessary(equalTo, 
conjuncts)
-                        ).anyMatch(Boolean::booleanValue);
+                        .map(equalTo -> 
JoinUtils.swapEqualToForChildrenOrder(equalTo, join.left().getOutputSet()))
+                        .anyMatch(equalTo -> 
createIsNotNullIfNecessary(equalTo, conjuncts));
 
                 JoinUtils.JoinSlotCoverageChecker checker = new 
JoinUtils.JoinSlotCoverageChecker(
                         join.left().getOutput(),
                         join.right().getOutput());
-                conjunctsChanged |= 
join.getOtherJoinConjuncts().stream().filter(EqualTo.class::isInstance)
-                        .map(EqualTo.class::cast)
-                        .filter(equalTo -> 
checker.isHashJoinCondition(equalTo))
-                        .map(equalTo -> (EqualTo) 
JoinUtils.swapEqualToForChildrenOrder(equalTo,
+                conjunctsChanged |= join.getOtherJoinConjuncts().stream()
+                        .filter(EqualTo.class::isInstance)
+                        .filter(equalTo -> 
checker.isHashJoinCondition((EqualPredicate) equalTo))
+                        .map(equalTo -> 
JoinUtils.swapEqualToForChildrenOrder((EqualPredicate) equalTo,
                                 join.left().getOutputSet()))
-                        .map(equalTo ->
-                            createIsNotNullIfNecessary(equalTo, conjuncts))
-                        .anyMatch(Boolean::booleanValue);
+                        .anyMatch(equalTo -> 
createIsNotNullIfNecessary(equalTo, conjuncts));
             }
             if (conjunctsChanged) {
                 return 
filter.withConjuncts(conjuncts.stream().collect(ImmutableSet.toImmutableSet()))
@@ -135,7 +132,7 @@ public class EliminateOuterJoin extends 
OneRewriteRuleFactory {
         return joinType;
     }
 
-    private boolean createIsNotNullIfNecessary(EqualTo swapedEqualTo, 
Collection<Expression> container) {
+    private boolean createIsNotNullIfNecessary(EqualPredicate swapedEqualTo, 
Collection<Expression> container) {
         boolean containerChanged = false;
         if (swapedEqualTo.left().nullable()) {
             Not not = new Not(new IsNull(swapedEqualTo.left()));
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownExpressionsInHashCondition.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownExpressionsInHashCondition.java
index 05da591526c..df7acb4553c 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownExpressionsInHashCondition.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownExpressionsInHashCondition.java
@@ -20,7 +20,7 @@ package org.apache.doris.nereids.rules.rewrite;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
 import org.apache.doris.nereids.trees.expressions.Alias;
-import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.EqualPredicate;
 import org.apache.doris.nereids.trees.expressions.ExprId;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
@@ -77,11 +77,10 @@ public class PushdownExpressionsInHashCondition extends 
OneRewriteRuleFactory {
         Set<NamedExpression> rightProjectExprs = Sets.newHashSet();
         Map<Expression, NamedExpression> exprReplaceMap = Maps.newHashMap();
         join.getHashJoinConjuncts().forEach(conjunct -> {
-            Preconditions.checkArgument(conjunct instanceof EqualTo);
+            Preconditions.checkArgument(conjunct instanceof EqualPredicate);
             // sometimes: t1 join t2 on t2.a + 1 = t1.a + 2, so check the 
situation, but actually it
             // doesn't swap the two sides.
-            conjunct = JoinUtils.swapEqualToForChildrenOrder(
-                    (EqualTo) conjunct, join.left().getOutputSet());
+            conjunct = JoinUtils.swapEqualToForChildrenOrder((EqualPredicate) 
conjunct, join.left().getOutputSet());
             generateReplaceMapAndProjectExprs(conjunct.child(0), 
exprReplaceMap, leftProjectExprs);
             generateReplaceMapAndProjectExprs(conjunct.child(1), 
exprReplaceMap, rightProjectExprs);
         });
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java
index 012dec4c91c..c1550cb5bd5 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java
@@ -27,13 +27,12 @@ import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.CaseWhen;
 import org.apache.doris.nereids.trees.expressions.Cast;
 import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
-import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.EqualPredicate;
 import org.apache.doris.nereids.trees.expressions.ExprId;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.InPredicate;
 import org.apache.doris.nereids.trees.expressions.IsNull;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
-import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.WhenClause;
@@ -306,7 +305,7 @@ public abstract class AbstractSelectMaterializedIndexRule {
 
         @Override
         public PrefixIndexCheckResult 
visitComparisonPredicate(ComparisonPredicate cp, Map<ExprId, String> context) {
-            if (cp instanceof EqualTo || cp instanceof NullSafeEqual) {
+            if (cp instanceof EqualPredicate) {
                 return check(cp, context, PrefixIndexCheckResult::createEqual);
             } else {
                 return check(cp, context, 
PrefixIndexCheckResult::createNonEqual);
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java
index 055f3b88a07..e2d7f40622f 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java
@@ -23,6 +23,7 @@ import org.apache.doris.nereids.trees.TreeNode;
 import org.apache.doris.nereids.trees.expressions.And;
 import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
 import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
+import org.apache.doris.nereids.trees.expressions.EqualPredicate;
 import org.apache.doris.nereids.trees.expressions.EqualTo;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.GreaterThan;
@@ -33,7 +34,6 @@ import org.apache.doris.nereids.trees.expressions.LessThan;
 import org.apache.doris.nereids.trees.expressions.LessThanEqual;
 import org.apache.doris.nereids.trees.expressions.Like;
 import org.apache.doris.nereids.trees.expressions.Not;
-import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
 import org.apache.doris.nereids.trees.expressions.Or;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
@@ -210,7 +210,7 @@ public class FilterEstimation extends 
ExpressionVisitor<Statistics, EstimationCo
             return context.statistics.withSel(DEFAULT_INEQUALITY_COEFFICIENT);
         }
 
-        if (cp instanceof EqualTo || cp instanceof NullSafeEqual) {
+        if (cp instanceof EqualPredicate) {
             return estimateEqualTo(cp, statsForLeft, statsForRight, context);
         } else {
             if (cp instanceof LessThan || cp instanceof LessThanEqual) {
@@ -255,7 +255,7 @@ public class FilterEstimation extends 
ExpressionVisitor<Statistics, EstimationCo
             ColumnStatistic statsForLeft, ColumnStatistic statsForRight) {
         Expression left = cp.left();
         Expression right = cp.right();
-        if (cp instanceof EqualTo || cp instanceof NullSafeEqual) {
+        if (cp instanceof EqualPredicate) {
             return estimateColumnEqualToColumn(left, statsForLeft, right, 
statsForRight, context);
         }
         if (cp instanceof GreaterThan || cp instanceof GreaterThanEqual) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java
index 3b7797439f3..d43171375b8 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java
@@ -19,7 +19,7 @@ package org.apache.doris.nereids.stats;
 
 import org.apache.doris.nereids.exceptions.AnalysisException;
 import org.apache.doris.nereids.trees.expressions.Cast;
-import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.EqualPredicate;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.plans.JoinType;
@@ -45,14 +45,14 @@ import java.util.stream.Collectors;
 public class JoinEstimation {
     private static double DEFAULT_ANTI_JOIN_SELECTIVITY_COEFFICIENT = 0.3;
 
-    private static EqualTo normalizeHashJoinCondition(EqualTo equalTo, 
Statistics leftStats, Statistics rightStats) {
-        boolean changeOrder = equalTo.left().getInputSlots().stream().anyMatch(
-                slot -> rightStats.findColumnStatistics(slot) != null
-        );
+    private static EqualPredicate normalizeHashJoinCondition(EqualPredicate 
equal, Statistics leftStats,
+            Statistics rightStats) {
+        boolean changeOrder = equal.left().getInputSlots().stream()
+                .anyMatch(slot -> rightStats.findColumnStatistics(slot) != 
null);
         if (changeOrder) {
-            return new EqualTo(equalTo.right(), equalTo.left());
+            return equal.commute();
         } else {
-            return equalTo;
+            return equal;
         }
     }
 
@@ -81,18 +81,18 @@ public class JoinEstimation {
          * In order to avoid error propagation, for unTrustEquations, we only 
use the biggest selectivity.
          */
         List<Double> unTrustEqualRatio = Lists.newArrayList();
-        List<EqualTo> unTrustableCondition = Lists.newArrayList();
+        List<EqualPredicate> unTrustableCondition = Lists.newArrayList();
         boolean leftBigger = leftStats.getRowCount() > 
rightStats.getRowCount();
         double rightStatsRowCount = 
StatsMathUtil.nonZeroDivisor(rightStats.getRowCount());
         double leftStatsRowCount = 
StatsMathUtil.nonZeroDivisor(leftStats.getRowCount());
-        List<EqualTo> trustableConditions = 
join.getHashJoinConjuncts().stream()
-                .map(expression -> (EqualTo) expression)
+        List<EqualPredicate> trustableConditions = 
join.getHashJoinConjuncts().stream()
+                .map(expression -> (EqualPredicate) expression)
                 .filter(
                         expression -> {
                             // since ndv is not accurate, if ndv/rowcount < 
almostUniqueThreshold,
                             // this column is regarded as unique.
                             double almostUniqueThreshold = 0.9;
-                            EqualTo equal = 
normalizeHashJoinCondition(expression, leftStats, rightStats);
+                            EqualPredicate equal = 
normalizeHashJoinCondition(expression, leftStats, rightStats);
                             ColumnStatistic eqLeftColStats = 
ExpressionEstimation.estimate(equal.left(), leftStats);
                             ColumnStatistic eqRightColStats = 
ExpressionEstimation.estimate(equal.right(), rightStats);
                             boolean trustable = eqRightColStats.ndv / 
rightStatsRowCount > almostUniqueThreshold
@@ -204,7 +204,7 @@ public class JoinEstimation {
     }
 
     private static double estimateSemiOrAntiRowCountBySlotsEqual(Statistics 
leftStats,
-            Statistics rightStats, Join join, EqualTo equalTo) {
+            Statistics rightStats, Join join, EqualPredicate equalTo) {
         Expression eqLeft = equalTo.left();
         Expression eqRight = equalTo.right();
         ColumnStatistic probColStats = leftStats.findColumnStatistics(eqLeft);
@@ -261,7 +261,7 @@ public class JoinEstimation {
         double rowCount = Double.POSITIVE_INFINITY;
         for (Expression conjunct : join.getHashJoinConjuncts()) {
             double eqRowCount = 
estimateSemiOrAntiRowCountBySlotsEqual(leftStats, rightStats,
-                    join, (EqualTo) conjunct);
+                    join, (EqualPredicate) conjunct);
             if (rowCount > eqRowCount) {
                 rowCount = eqRowCount;
             }
@@ -336,7 +336,7 @@ public class JoinEstimation {
     private static Statistics 
updateJoinResultStatsByHashJoinCondition(Statistics innerStats, Join join) {
         Map<Expression, ColumnStatistic> updatedCols = new HashMap<>();
         for (Expression expr : join.getHashJoinConjuncts()) {
-            EqualTo equalTo = (EqualTo) expr;
+            EqualPredicate equalTo = (EqualPredicate) expr;
             ColumnStatistic leftColStats = 
ExpressionEstimation.estimate(equalTo.left(), innerStats);
             ColumnStatistic rightColStats = 
ExpressionEstimation.estimate(equalTo.right(), innerStats);
             double minNdv = Math.min(leftColStats.ndv, rightColStats.ndv);
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualPredicate.java
new file mode 100644
index 00000000000..3f61bd3cf62
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualPredicate.java
@@ -0,0 +1,36 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions;
+
+import java.util.List;
+
+/**
+ * EqualPredicate
+ */
+public abstract class EqualPredicate extends ComparisonPredicate {
+
+    protected EqualPredicate(List<Expression> children, String symbol) {
+        super(children, symbol);
+    }
+
+    @Override
+    public EqualPredicate commute() {
+        return null;
+    }
+}
+
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java
index 065f6b93403..3faccff6d99 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java
@@ -29,7 +29,7 @@ import java.util.List;
 /**
  * Equal to expression: a = b.
  */
-public class EqualTo extends ComparisonPredicate implements PropagateNullable {
+public class EqualTo extends EqualPredicate implements PropagateNullable {
 
     public EqualTo(Expression left, Expression right) {
         super(ImmutableList.of(left, right), "=");
@@ -55,7 +55,7 @@ public class EqualTo extends ComparisonPredicate implements 
PropagateNullable {
     }
 
     @Override
-    public ComparisonPredicate commute() {
+    public EqualTo commute() {
         return new EqualTo(right(), left());
     }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NullSafeEqual.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NullSafeEqual.java
index c2b63aebbd7..48d05364fa3 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NullSafeEqual.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NullSafeEqual.java
@@ -29,13 +29,7 @@ import java.util.List;
  * Null safe equal expression: a <=> b.
  * Unlike normal equal to expression, null <=> null is true.
  */
-public class NullSafeEqual extends ComparisonPredicate implements 
AlwaysNotNullable {
-    /**
-     * Constructor of Null Safe Equal ComparisonPredicate.
-     *
-     * @param left  left child of Null Safe Equal
-     * @param right right child of Null Safe Equal
-     */
+public class NullSafeEqual extends EqualPredicate implements AlwaysNotNullable 
{
     public NullSafeEqual(Expression left, Expression right) {
         super(ImmutableList.of(left, right), "<=>");
     }
@@ -61,8 +55,7 @@ public class NullSafeEqual extends ComparisonPredicate 
implements AlwaysNotNulla
     }
 
     @Override
-    public ComparisonPredicate commute() {
+    public NullSafeEqual commute() {
         return new NullSafeEqual(right(), left());
     }
-
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Join.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Join.java
index 77bf6c9148d..3f96c4d11cc 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Join.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Join.java
@@ -17,6 +17,7 @@
 
 package org.apache.doris.nereids.trees.plans.algebra;
 
+import org.apache.doris.nereids.trees.expressions.EqualTo;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
 import org.apache.doris.nereids.trees.plans.JoinHint;
@@ -25,6 +26,7 @@ import org.apache.doris.nereids.trees.plans.JoinType;
 
 import java.util.List;
 import java.util.Optional;
+import java.util.stream.Collectors;
 
 /**
  * Common interface for logical/physical join.
@@ -34,6 +36,11 @@ public interface Join {
 
     List<Expression> getHashJoinConjuncts();
 
+    default List<EqualTo> getEqualToConjuncts() {
+        return 
getHashJoinConjuncts().stream().filter(EqualTo.class::isInstance).map(EqualTo.class::cast)
+                .collect(Collectors.toList());
+    }
+
     List<Expression> getOtherJoinConjuncts();
 
     Optional<Expression> getOnClauseCondition();
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashJoin.java
index 0041812796a..183ccaabfa8 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashJoin.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashJoin.java
@@ -25,7 +25,7 @@ import 
org.apache.doris.nereids.processor.post.RuntimeFilterContext;
 import org.apache.doris.nereids.processor.post.RuntimeFilterGenerator;
 import org.apache.doris.nereids.properties.LogicalProperties;
 import org.apache.doris.nereids.properties.PhysicalProperties;
-import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.EqualPredicate;
 import org.apache.doris.nereids.trees.expressions.ExprId;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
@@ -213,7 +213,7 @@ public class PhysicalHashJoin<
         if (ConnectContext.get() != null && 
ConnectContext.get().getSessionVariable().expandRuntimeFilterByInnerJoin) {
             if (!this.equals(builderNode) && this.getJoinType() == 
JoinType.INNER_JOIN) {
                 for (Expression expr : this.getHashJoinConjuncts()) {
-                    EqualTo equalTo = (EqualTo) expr;
+                    EqualPredicate equalTo = (EqualPredicate) expr;
                     if (probeExpr.equals(equalTo.left())) {
                         probExprList.add(equalTo.right());
                     } else if (probeExpr.equals(equalTo.right())) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java
index bcf53ce29f8..862bf02e464 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java
@@ -24,7 +24,7 @@ import org.apache.doris.nereids.properties.DistributionSpec;
 import org.apache.doris.nereids.properties.DistributionSpecHash;
 import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType;
 import org.apache.doris.nereids.properties.DistributionSpecReplicated;
-import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.EqualPredicate;
 import org.apache.doris.nereids.trees.expressions.ExprId;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.Not;
@@ -88,22 +88,6 @@ public class JoinUtils {
             rightExprIds = 
right.stream().map(Slot::getExprId).collect(Collectors.toSet());
         }
 
-        JoinSlotCoverageChecker(Set<ExprId> left, Set<ExprId> right) {
-            leftExprIds = left;
-            rightExprIds = right;
-        }
-
-        /**
-         * PushDownExpressionInHashConjuncts ensure the "slots" is only one 
slot.
-         */
-        boolean isCoveredByLeftSlots(ExprId slot) {
-            return leftExprIds.contains(slot);
-        }
-
-        boolean isCoveredByRightSlots(ExprId slot) {
-            return rightExprIds.contains(slot);
-        }
-
         /**
          * consider following cases:
          * 1# A=1 => not for hash table
@@ -112,25 +96,20 @@ public class JoinUtils {
          * 4# t1.a=t2.a or t1.b=t2.b not for hash table
          * 5# t1.a > 1 not for hash table
          *
-         * @param equalTo a conjunct in on clause condition
+         * @param equal a conjunct in on clause condition
          * @return true if the equal can be used as hash join condition
          */
-        public boolean isHashJoinCondition(EqualTo equalTo) {
-            Set<Slot> equalLeft = equalTo.left().getInputSlots();
-            if (equalLeft.isEmpty()) {
+        public boolean isHashJoinCondition(EqualPredicate equal) {
+            Set<ExprId> equalLeftExprIds = equal.left().getInputSlotExprIds();
+            if (equalLeftExprIds.isEmpty()) {
                 return false;
             }
 
-            Set<Slot> equalRight = equalTo.right().getInputSlots();
-            if (equalRight.isEmpty()) {
+            Set<ExprId> equalRightExprIds = 
equal.right().getInputSlotExprIds();
+            if (equalRightExprIds.isEmpty()) {
                 return false;
             }
 
-            List<ExprId> equalLeftExprIds = equalLeft.stream()
-                    .map(Slot::getExprId).collect(Collectors.toList());
-
-            List<ExprId> equalRightExprIds = equalRight.stream()
-                    .map(Slot::getExprId).collect(Collectors.toList());
             return leftExprIds.containsAll(equalLeftExprIds) && 
rightExprIds.containsAll(equalRightExprIds)
                     || leftExprIds.containsAll(equalRightExprIds) && 
rightExprIds.containsAll(equalLeftExprIds);
         }
@@ -147,9 +126,8 @@ public class JoinUtils {
     public static Pair<List<Expression>, List<Expression>> 
extractExpressionForHashTable(List<Slot> leftSlots,
             List<Slot> rightSlots, List<Expression> onConditions) {
         JoinSlotCoverageChecker checker = new 
JoinSlotCoverageChecker(leftSlots, rightSlots);
-        Map<Boolean, List<Expression>> mapper = onConditions.stream()
-                .collect(Collectors.groupingBy(
-                        expr -> (expr instanceof EqualTo) && 
checker.isHashJoinCondition((EqualTo) expr)));
+        Map<Boolean, List<Expression>> mapper = 
onConditions.stream().collect(Collectors.groupingBy(
+                expr -> (expr instanceof EqualPredicate) && 
checker.isHashJoinCondition((EqualPredicate) expr)));
         return Pair.of(
                 mapper.getOrDefault(true, ImmutableList.of()),
                 mapper.getOrDefault(false, ImmutableList.of())
@@ -205,7 +183,7 @@ public class JoinUtils {
      * The left child of origin predicate is t2.id and the right child of 
origin predicate is t1.id.
      * In this situation, the children of predicate need to be swap => 
t1.id=t2.id.
      */
-    public static Expression swapEqualToForChildrenOrder(EqualTo equalTo, 
Set<Slot> leftOutput) {
+    public static EqualPredicate swapEqualToForChildrenOrder(EqualPredicate 
equalTo, Set<Slot> leftOutput) {
         if (leftOutput.containsAll(equalTo.left().getInputSlots())) {
             return equalTo;
         } else {
diff --git a/regression-test/data/nereids_p0/join/test_join_15.out b/regression-test/data/nereids_p0/join/test_join_15.out
index 5c9df35ba7b..e535253a28d 100644
--- a/regression-test/data/nereids_p0/join/test_join_15.out
+++ b/regression-test/data/nereids_p0/join/test_join_15.out
@@ -172,12 +172,20 @@ false     true    true    false   false
 3      \N      null    2019-09-09      \N      8.9     3       \N      null    
2019-09-09      \N      8.9
 5      \N      null    \N      2019-09-09T00:00        8.9     5       \N      
null    \N      2019-09-09T00:00        8.9
 
--- !hash_join --
+-- !hash_right_join --
 \N     \N      \N      \N      \N      \N      1       \N      null    \N      
\N      8.9
 \N     \N      \N      \N      \N      \N      2       \N      2       \N      
\N      8.9
 \N     \N      \N      \N      \N      \N      3       \N      null    
2019-09-09      \N      8.9
 \N     \N      \N      \N      \N      \N      5       \N      null    \N      
2019-09-09T00:00        8.9
 
+-- !hash_left_join --
+1      \N      null    \N      \N      8.9     \N      \N      \N      \N      
\N      \N
+2      \N      2       \N      \N      8.9     \N      \N      \N      \N      
\N      \N
+3      \N      null    2019-09-09      \N      8.9     \N      \N      \N      
\N      \N      \N
+5      \N      null    \N      2019-09-09T00:00        8.9     \N      \N      
\N      \N      \N      \N
+
+-- !hash_inner_join --
+
 -- !cross_join --
 \N     \N      \N      \N      \N      \N      1       \N      null    \N      
\N      8.9
 \N     \N      \N      \N      \N      \N      2       \N      2       \N      
\N      8.9
@@ -226,12 +234,20 @@ false     true    true    false   false
 5      \N      null    \N      2019-09-09T00:00        8.9     3       \N      
null    2019-09-09      \N      8.9
 5      \N      null    \N      2019-09-09T00:00        8.9     5       \N      
null    \N      2019-09-09T00:00        8.9
 
--- !hash_join --
+-- !hash_right_join --
 \N     \N      \N      \N      \N      \N      1       \N      null    \N      
\N      8.9
 \N     \N      \N      \N      \N      \N      2       \N      2       \N      
\N      8.9
 \N     \N      \N      \N      \N      \N      3       \N      null    
2019-09-09      \N      8.9
 \N     \N      \N      \N      \N      \N      5       \N      null    \N      
2019-09-09T00:00        8.9
 
+-- !hash_left_join --
+1      \N      null    \N      \N      8.9     \N      \N      \N      \N      
\N      \N
+2      \N      2       \N      \N      8.9     \N      \N      \N      \N      
\N      \N
+3      \N      null    2019-09-09      \N      8.9     \N      \N      \N      
\N      \N      \N
+5      \N      null    \N      2019-09-09T00:00        8.9     \N      \N      
\N      \N      \N      \N
+
+-- !hash_inner_join --
+
 -- !cross_join --
 \N     \N      \N      \N      \N      \N      1       \N      null    \N      
\N      8.9
 \N     \N      \N      \N      \N      \N      2       \N      2       \N      
\N      8.9
@@ -271,12 +287,20 @@ false     true    true    false   false
 5      \N      null    \N      2019-09-09T00:00        8.9     3       \N      
null    2019-09-09      \N      8.9
 5      \N      null    \N      2019-09-09T00:00        8.9     5       \N      
null    \N      2019-09-09T00:00        8.9
 
--- !hash_join --
+-- !hash_right_join --
 \N     \N      \N      \N      \N      \N      1       \N      null    \N      
\N      8.9
 \N     \N      \N      \N      \N      \N      2       \N      2       \N      
\N      8.9
 \N     \N      \N      \N      \N      \N      3       \N      null    
2019-09-09      \N      8.9
 \N     \N      \N      \N      \N      \N      5       \N      null    \N      
2019-09-09T00:00        8.9
 
+-- !hash_left_join --
+1      \N      null    \N      \N      8.9     \N      \N      \N      \N      
\N      \N
+2      \N      2       \N      \N      8.9     \N      \N      \N      \N      
\N      \N
+3      \N      null    2019-09-09      \N      8.9     \N      \N      \N      
\N      \N      \N
+5      \N      null    \N      2019-09-09T00:00        8.9     \N      \N      
\N      \N      \N      \N
+
+-- !hash_inner_join --
+
 -- !cross_join --
 \N     \N      \N      \N      \N      \N      1       \N      null    \N      
\N      8.9
 \N     \N      \N      \N      \N      \N      2       \N      2       \N      
\N      8.9
@@ -314,12 +338,20 @@ false     true    true    false   false
 5      \N      null    \N      2019-09-09T00:00        8.9     2       \N      
2       \N      \N      8.9
 5      \N      null    \N      2019-09-09T00:00        8.9     5       \N      
null    \N      2019-09-09T00:00        8.9
 
--- !hash_join --
+-- !hash_right_join --
 \N     \N      \N      \N      \N      \N      1       \N      null    \N      
\N      8.9
 \N     \N      \N      \N      \N      \N      2       \N      2       \N      
\N      8.9
 \N     \N      \N      \N      \N      \N      3       \N      null    
2019-09-09      \N      8.9
 \N     \N      \N      \N      \N      \N      5       \N      null    \N      
2019-09-09T00:00        8.9
 
+-- !hash_left_join --
+1      \N      null    \N      \N      8.9     \N      \N      \N      \N      
\N      \N
+2      \N      2       \N      \N      8.9     \N      \N      \N      \N      
\N      \N
+3      \N      null    2019-09-09      \N      8.9     \N      \N      \N      
\N      \N      \N
+5      \N      null    \N      2019-09-09T00:00        8.9     \N      \N      
\N      \N      \N      \N
+
+-- !hash_inner_join --
+
 -- !cross_join --
 \N     \N      \N      \N      \N      \N      1       \N      null    \N      
\N      8.9
 \N     \N      \N      \N      \N      \N      2       \N      2       \N      
\N      8.9
@@ -357,12 +389,20 @@ false     true    true    false   false
 3      \N      null    2019-09-09      \N      8.9     3       \N      null    
2019-09-09      \N      8.9
 5      \N      null    \N      2019-09-09T00:00        8.9     5       \N      
null    \N      2019-09-09T00:00        8.9
 
--- !hash_join --
+-- !hash_right_join --
 \N     \N      \N      \N      \N      \N      1       \N      null    \N      
\N      8.9
 \N     \N      \N      \N      \N      \N      2       \N      2       \N      
\N      8.9
 \N     \N      \N      \N      \N      \N      3       \N      null    
2019-09-09      \N      8.9
 \N     \N      \N      \N      \N      \N      5       \N      null    \N      
2019-09-09T00:00        8.9
 
+-- !hash_left_join --
+1      \N      null    \N      \N      8.9     \N      \N      \N      \N      
\N      \N
+2      \N      2       \N      \N      8.9     \N      \N      \N      \N      
\N      \N
+3      \N      null    2019-09-09      \N      8.9     \N      \N      \N      
\N      \N      \N
+5      \N      null    \N      2019-09-09T00:00        8.9     \N      \N      
\N      \N      \N      \N
+
+-- !hash_inner_join --
+
 -- !cross_join --
 \N     \N      \N      \N      \N      \N      1       \N      null    \N      
\N      8.9
 \N     \N      \N      \N      \N      \N      2       \N      2       \N      
\N      8.9
@@ -412,12 +452,20 @@ false     true    true    false   false
 5      \N      null    \N      2019-09-09T00:00        8.9     3       \N      
null    2019-09-09      \N      8.9
 5      \N      null    \N      2019-09-09T00:00        8.9     5       \N      
null    \N      2019-09-09T00:00        8.9
 
--- !hash_join --
+-- !hash_right_join --
 \N     \N      \N      \N      \N      \N      1       \N      null    \N      
\N      8.9
 \N     \N      \N      \N      \N      \N      2       \N      2       \N      
\N      8.9
 \N     \N      \N      \N      \N      \N      3       \N      null    
2019-09-09      \N      8.9
 \N     \N      \N      \N      \N      \N      5       \N      null    \N      
2019-09-09T00:00        8.9
 
+-- !hash_left_join --
+1      \N      null    \N      \N      8.9     \N      \N      \N      \N      
\N      \N
+2      \N      2       \N      \N      8.9     \N      \N      \N      \N      
\N      \N
+3      \N      null    2019-09-09      \N      8.9     \N      \N      \N      
\N      \N      \N
+5      \N      null    \N      2019-09-09T00:00        8.9     \N      \N      
\N      \N      \N      \N
+
+-- !hash_inner_join --
+
 -- !cross_join --
 \N     \N      \N      \N      \N      \N      1       \N      null    \N      
\N      8.9
 \N     \N      \N      \N      \N      \N      2       \N      2       \N      
\N      8.9
diff --git a/regression-test/suites/nereids_p0/join/test_join_15.groovy b/regression-test/suites/nereids_p0/join/test_join_15.groovy
index 9778e45eb16..22e6f8a06a3 100644
--- a/regression-test/suites/nereids_p0/join/test_join_15.groovy
+++ b/regression-test/suites/nereids_p0/join/test_join_15.groovy
@@ -192,7 +192,11 @@ suite("test_join_15", "nereids_p0") {
             order by a.k1, b.k1"""
         qt_right_join"""select * from ${null_table_1} a right join 
${null_table_1} b on  a.k${index}<=>b.k${index}
             order by a.k1, b.k1"""
-        qt_hash_join"""select * from ${null_table_1} a right join 
${null_table_1} b on  a.k${index}<=>b.k${index} and a.k2=b.k2
+        qt_hash_right_join"""select * from ${null_table_1} a right join 
${null_table_1} b on a.k${index}<=>b.k${index} and a.k2=b.k2
+            order by a.k1, b.k1"""
+        qt_hash_left_join"""select * from ${null_table_1} a left join 
${null_table_1} b on a.k${index}<=>b.k${index} and a.k2=b.k2
+            order by a.k1, b.k1"""
+        qt_hash_inner_join"""select * from ${null_table_1} a inner join 
${null_table_1} b on a.k${index}<=>b.k${index} and a.k2=b.k2
             order by a.k1, b.k1"""
         qt_cross_join"""select * from ${null_table_1} a right join 
${null_table_1} b on  a.k${index}<=>b.k${index} and a.k2 !=b.k2
             order by a.k1, b.k1"""


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to