This is an automated email from the ASF dual-hosted git repository.

chunwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new 4969b9690 [CALCITE-5073] JoinConditionPushRule cannot infer 'LHS.C1 = LHS.C2' from 'LHS.C1 = RHS.C1 AND LHS.C2 = RHS.C1'
4969b9690 is described below

commit 4969b9690efe999c522daf1151c4a00a33be0eb0
Author: Benchao Li <[email protected]>
AuthorDate: Fri Apr 1 18:12:27 2022 +0800

    [CALCITE-5073] JoinConditionPushRule cannot infer 'LHS.C1 = LHS.C2' from 'LHS.C1 = RHS.C1 AND LHS.C2 = RHS.C1'
    
    Cosmetic fix-ups by Chunwei Lei.
    
    Close apache/calcite#2761
---
 .../apache/calcite/rel/rules/FilterJoinRule.java   | 136 ++++++++++++++++++++-
 .../calcite/rel/rel2sql/RelToSqlConverterTest.java |  15 +--
 .../org/apache/calcite/test/RelOptRulesTest.java   |  34 ++++++
 .../org/apache/calcite/test/RelOptRulesTest.xml    |  55 +++++++++
 4 files changed, 230 insertions(+), 10 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/rel/rules/FilterJoinRule.java b/core/src/main/java/org/apache/calcite/rel/rules/FilterJoinRule.java
index ea02f8276..0df08263a 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/FilterJoinRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/FilterJoinRule.java
@@ -27,8 +27,11 @@ import org.apache.calcite.rel.core.RelFactories;
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexInputRef;
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.rex.RexUtil;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.calcite.tools.RelBuilder;
 import org.apache.calcite.tools.RelBuilderFactory;
 
@@ -40,7 +43,9 @@ import org.immutables.value.Value;
 
 import java.util.ArrayList;
 import java.util.Iterator;
+import java.util.LinkedHashSet;
 import java.util.List;
+import java.util.Set;
 
 import static org.apache.calcite.plan.RelOptUtil.conjunctions;
 
@@ -67,7 +72,7 @@ public abstract class FilterJoinRule<C extends FilterJoinRule.Config>
 
   protected void perform(RelOptRuleCall call, @Nullable Filter filter,
       Join join) {
-    final List<RexNode> joinFilters =
+    List<RexNode> joinFilters =
         RelOptUtil.conjunctions(join.getCondition());
     final List<RexNode> origJoinFilters = ImmutableList.copyOf(joinFilters);
 
@@ -133,6 +138,8 @@ public abstract class FilterJoinRule<C extends FilterJoinRule.Config>
       }
     }
 
+    joinFilters = inferJoinEqualConditions(joinFilters, join);
+
     // Try to push down filters in ON clause. A ON clause filter can only be
     // pushed down if it does not affect the non-matching set, i.e. it is
     // not on the side which is preserved.
@@ -225,6 +232,133 @@ public abstract class FilterJoinRule<C extends FilterJoinRule.Config>
     call.transformTo(relBuilder.build());
   }
 
+  /**
+   * Infers more equal conditions for the join condition.
+   *
+   * <p>For example, in {@code SELECT * FROM T1, T2, T3 WHERE T1.id = T3.id AND T2.id = T3.id},
+   * we can infer {@code T1.id = T2.id} for the first Join node from the second
+   * Join node's condition: {@code T1.id = T3.id AND T2.id = T3.id}.
+   *
+   * <p>For the above SQL, the second Join's condition is
+   * {@code T1.id = T3.id AND T2.id = T3.id}. After inference, the final
+   * condition would be {@code T1.id = T2.id AND T1.id = T3.id}, and
+   * {@code T1.id = T2.id} can be further pushed into the LHS.
+   *
+   * <p>The equalities are rewritten only when at least one equal set has more
+   * than two members; otherwise {@code rexNodes} is returned as is.
+   *
+   * @param rexNodes the conjuncts of the Join condition
+   * @param join the Join node
+   * @return the complete list of conjuncts to use in place of
+   *     {@code rexNodes}: the conjuncts that were not absorbed into an equal
+   *     set, followed by the equalities rebuilt from the equal sets; or
+   *     {@code rexNodes} itself when nothing can be inferred
+   */
+  protected List<RexNode> inferJoinEqualConditions(List<RexNode> rexNodes,
+      Join join) {
+    final List<RexNode> result = new ArrayList<>(rexNodes.size());
+    final List<Set<Integer>> equalSets = splitEqualSets(rexNodes, result);
+
+    // A set with two members corresponds to an equality that is already
+    // present; only a set with three or more members yields a new one.
+    final boolean needOptimize =
+        equalSets.stream().anyMatch(set -> set.size() > 2);
+    if (!needOptimize) {
+      // Keep the conditions unchanged.
+      return rexNodes;
+    }
+
+    result.addAll(constructConditionFromEqualSets(join, equalSets));
+    return result;
+  }
+
+  /**
+   * Splits out the equal sets.
+   *
+   * <p>Each condition of the form {@code $i = $j}, where {@code $i} and
+   * {@code $j} are two distinct input references, is added to an equal set.
+   * A condition that links two existing sets merges them, so every input
+   * reference belongs to at most one set. Every other condition is appended
+   * to {@code leftNodes} unchanged. This includes {@code $i = $i}, which is
+   * not trivially true because it rejects NULL.
+   *
+   * @param rexNodes the original conditions
+   * @param leftNodes output list that receives the conditions not absorbed
+   *     into an equal set (unrelated to the join's left input)
+   * @return the equal sets, each iterating in first-seen order
+   */
+  private List<Set<Integer>> splitEqualSets(List<RexNode> rexNodes,
+      List<RexNode> leftNodes) {
+    final List<Set<Integer>> equalSets = new ArrayList<>();
+    for (RexNode rexNode : rexNodes) {
+      if (!rexNode.isA(SqlKind.EQUALS)) {
+        leftNodes.add(rexNode);
+        continue;
+      }
+      final RexNode op1 = ((RexCall) rexNode).getOperands().get(0);
+      final RexNode op2 = ((RexCall) rexNode).getOperands().get(1);
+      if (!(op1 instanceof RexInputRef) || !(op2 instanceof RexInputRef)) {
+        leftNodes.add(rexNode);
+        continue;
+      }
+      final int index1 = ((RexInputRef) op1).getIndex();
+      final int index2 = ((RexInputRef) op2).getIndex();
+      if (index1 == index2) {
+        // "$i = $i" is equivalent to "$i IS NOT NULL". As a singleton equal
+        // set it would produce no condition, and the NULL check would be lost.
+        leftNodes.add(rexNode);
+        continue;
+      }
+      Set<Integer> set = null;
+      final Iterator<Set<Integer>> iterator = equalSets.iterator();
+      while (iterator.hasNext()) {
+        final Set<Integer> candidate = iterator.next();
+        if (candidate.contains(index1) || candidate.contains(index2)) {
+          if (set == null) {
+            set = candidate;
+          } else {
+            // This condition bridges two existing sets; merge them so the
+            // transitive closure is represented by a single set.
+            set.addAll(candidate);
+            iterator.remove();
+          }
+        }
+      }
+      if (set == null) {
+        // To make the result deterministic.
+        set = new LinkedHashSet<>();
+        equalSets.add(set);
+      }
+      set.add(index1);
+      set.add(index2);
+    }
+
+    return equalSets;
+  }
+
+  /**
+   * Builds equality conditions from the equal sets.
+   *
+   * <p>The members of each set are partitioned by the join input they
+   * come from (left if the field index is below the left input's field
+   * count, otherwise right). On each side, every member after the first is
+   * equated with that side's first member. When both sides are non-empty, a
+   * single condition equates the first left member with the first right
+   * member.
+   *
+   * @param join the original {@link Join} node, used to build input references
+   * @param equalSets the equal sets
+   * @return the conditions, emitted per set in the order: left-side
+   *     equalities, right-side equalities, then the cross-input equality
+   */
+  private List<RexNode> constructConditionFromEqualSets(Join join,
+      List<Set<Integer>> equalSets) {
+    final RexBuilder rexBuilder = join.getCluster().getRexBuilder();
+    final List<RexNode> conditions = new ArrayList<>();
+    final int leftFieldCount = join.getLeft().getRowType().getFieldCount();
+    for (Set<Integer> equalSet : equalSets) {
+      final List<Integer> lhs = new ArrayList<>();
+      final List<Integer> rhs = new ArrayList<>();
+      for (int field : equalSet) {
+        (field < leftFieldCount ? lhs : rhs).add(field);
+      }
+      addChainedEquals(rexBuilder, join, lhs, conditions);
+      addChainedEquals(rexBuilder, join, rhs, conditions);
+      // One equality is enough to connect the two sides of the set.
+      if (!lhs.isEmpty() && !rhs.isEmpty()) {
+        conditions.add(
+            makeInputRefEquals(rexBuilder, join, lhs.get(0), rhs.get(0)));
+      }
+    }
+
+    return conditions;
+  }
+
+  /** Appends {@code fields[0] = fields[k]} to {@code sink} for each k &ge; 1. */
+  private static void addChainedEquals(RexBuilder rexBuilder, Join join,
+      List<Integer> fields, List<RexNode> sink) {
+    for (int k = 1; k < fields.size(); k++) {
+      sink.add(makeInputRefEquals(rexBuilder, join, fields.get(0), fields.get(k)));
+    }
+  }
+
+  /** Creates {@code $first = $second} over the output row of {@code join}. */
+  private static RexNode makeInputRefEquals(RexBuilder rexBuilder, Join join,
+      int first, int second) {
+    return rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
+        rexBuilder.makeInputRef(join, first),
+        rexBuilder.makeInputRef(join, second));
+  }
+
   /**
    * Get conjunctions of filter's condition but with collapsed
    * {@code IS NOT DISTINCT FROM} expressions if needed.
diff --git a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java
index 4f6ee5d83..b3fbe2a56 100644
--- a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java
+++ b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java
@@ -5882,8 +5882,8 @@ class RelToSqlConverterTest {
         + "(SELECT \"department_id\"\n"
         + "FROM \"foodmart\".\"employee\"\n"
         + "GROUP BY \"department_id\") \"t1\"\n"
-        + "GROUP BY \"t1\".\"department_id\") \"t3\" ON 
\"employee\".\"department_id\" = \"t3\".\"department_id0\""
-        + " AND \"employee\".\"department_id\" = \"t3\".\"EXPR$0\"";
+        + "GROUP BY \"t1\".\"department_id\"\n"
+        + "HAVING \"t1\".\"department_id\" = MIN(\"t1\".\"department_id\")) 
\"t4\" ON \"employee\".\"department_id\" = \"t4\".\"department_id0\"";
     sql(query).withOracle().ok(expected);
   }
 
@@ -5893,18 +5893,15 @@ class RelToSqlConverterTest {
         + " where A.\"department_id\" = ( select min( A.\"department_id\") 
from \"foodmart\".\"department\" B where 1=2 )";
     final String expected = "SELECT \"employee\".\"department_id\"\n"
         + "FROM \"foodmart\".\"employee\"\n"
-        + "INNER JOIN (SELECT \"t1\".\"department_id\" AS \"department_id0\","
-        + " MIN(\"t1\".\"department_id\") AS \"EXPR$0\"\n"
+        + "INNER JOIN (SELECT \"t1\".\"department_id\" AS \"department_id0\", 
MIN(\"t1\".\"department_id\") AS \"EXPR$0\"\n"
         + "FROM (SELECT *\n"
-        + "FROM (VALUES (NULL, NULL))"
-        + " AS \"t\" (\"department_id\", \"department_description\")\n"
+        + "FROM (VALUES (NULL, NULL)) AS \"t\" (\"department_id\", 
\"department_description\")\n"
         + "WHERE 1 = 0) AS \"t\",\n"
         + "(SELECT \"department_id\"\n"
         + "FROM \"foodmart\".\"employee\"\n"
         + "GROUP BY \"department_id\") AS \"t1\"\n"
-        + "GROUP BY \"t1\".\"department_id\") AS \"t3\" "
-        + "ON \"employee\".\"department_id\" = \"t3\".\"department_id0\""
-        + " AND \"employee\".\"department_id\" = \"t3\".\"EXPR$0\"";
+        + "GROUP BY \"t1\".\"department_id\"\n"
+        + "HAVING \"t1\".\"department_id\" = MIN(\"t1\".\"department_id\")) AS 
\"t4\" ON \"employee\".\"department_id\" = \"t4\".\"department_id0\"";
     sql(query).ok(expected);
   }
 
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index db759c298..dc4b1a41d 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -2148,6 +2148,40 @@ class RelOptRulesTest extends RelOptTestBase {
         .check();
   }
 
+  /**
+   * Test case for <a href="https://issues.apache.org/jira/browse/CALCITE-5073">[CALCITE-5073]
+   * JoinConditionPushRule cannot infer 'LHS.C1 = LHS.C2' from
+   * 'LHS.C1 = RHS.C1 AND LHS.C2 = RHS.C1'</a>.
+   *
+   * <p>Both employees are joined to the same department, so
+   * {@code e1.deptno = e2.deptno} is inferred and pushed into the lower join.
+   */
+  @Test void testJoinConditionPushdown1() {
+    final String sql = "select *\n"
+        + "from emp e1, emp e2, dept d2\n"
+        + "where e1.deptno = d2.deptno and e2.deptno = d2.deptno";
+    sql(sql)
+        .withRule(CoreRules.FILTER_INTO_JOIN,
+            CoreRules.JOIN_CONDITION_PUSH,
+            CoreRules.PROJECT_MERGE,
+            CoreRules.FILTER_PROJECT_TRANSPOSE)
+        .check();
+  }
+
+  /**
+   * Test case for <a href="https://issues.apache.org/jira/browse/CALCITE-5073">[CALCITE-5073]
+   * JoinConditionPushRule cannot infer 'LHS.C1 = LHS.C2' from
+   * 'LHS.C1 = RHS.C1 AND LHS.C2 = RHS.C1'</a>.
+   *
+   * <p>Two columns of the same input are equated to one column of the other,
+   * so {@code e.deptno = e.empno} is inferred and pushed below the join as a
+   * filter on EMP.
+   */
+  @Test void testJoinConditionPushdown2() {
+    final String sql = "select *\n"
+        + "from emp e, dept d\n"
+        + "where e.deptno = d.deptno and e.empno = d.deptno";
+    sql(sql)
+        .withRule(CoreRules.FILTER_INTO_JOIN,
+            CoreRules.JOIN_CONDITION_PUSH,
+            CoreRules.PROJECT_MERGE,
+            CoreRules.FILTER_PROJECT_TRANSPOSE)
+        .check();
+  }
+
   /** Tests that filters are combined if they are identical. */
   @Test void testMergeFilter() {
     final String sql = "select name from (\n"
diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index 70819a8ba..fa18df0f9 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -4442,6 +4442,61 @@ LogicalProject(EMPNO=[$3], ENAME=[$4], JOB=[$5], MGR=[$6], HIREDATE=[$7], SAL=[$
 LogicalJoin(condition=[true], joinType=[inner])
   LogicalTableScan(table=[[scott, EMP]])
   LogicalTableScan(table=[[scott, DEPT]])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testJoinConditionPushdown1">
+    <Resource name="sql">
+      <![CDATA[select *
+from emp e1, emp e2, dept d2
+where e1.deptno = d2.deptno and e2.deptno = d2.deptno
+]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], EMPNO0=[$9], ENAME0=[$10], JOB0=[$11], MGR0=[$12], HIREDATE0=[$13], SAL0=[$14], COMM0=[$15], DEPTNO0=[$16], SLACKER0=[$17], DEPTNO1=[$18], NAME=[$19])
+  LogicalFilter(condition=[AND(=($7, $18), =($16, $18))])
+    LogicalJoin(condition=[true], joinType=[inner])
+      LogicalJoin(condition=[true], joinType=[inner])
+        LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+        LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+      LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+    </Resource>
+    <Resource name="planAfter">
+      <![CDATA[
+LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], EMPNO0=[$9], ENAME0=[$10], JOB0=[$11], MGR0=[$12], HIREDATE0=[$13], SAL0=[$14], COMM0=[$15], DEPTNO0=[$16], SLACKER0=[$17], DEPTNO1=[$18], NAME=[$19])
+  LogicalJoin(condition=[=($7, $18)], joinType=[inner])
+    LogicalJoin(condition=[=($7, $16)], joinType=[inner])
+      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+    LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testJoinConditionPushdown2">
+    <Resource name="sql">
+      <![CDATA[select *
+from emp e, dept d
+where e.deptno = d.deptno and e.empno = d.deptno
+]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], DEPTNO0=[$9], NAME=[$10])
+  LogicalFilter(condition=[AND(=($7, $9), =($0, $9))])
+    LogicalJoin(condition=[true], joinType=[inner])
+      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+      LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+    </Resource>
+    <Resource name="planAfter">
+      <![CDATA[
+LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], DEPTNO0=[$9], NAME=[$10])
+  LogicalJoin(condition=[=($7, $9)], joinType=[inner])
+    LogicalFilter(condition=[=($7, $0)])
+      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+    LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
 ]]>
     </Resource>
   </TestCase>

Reply via email to