Repository: calcite
Updated Branches:
  refs/heads/master 46654ad2a -> 565d63926


[CALCITE-1495] SemiJoinRule should not apply to RIGHT and FULL JOIN, and should strip LEFT JOIN

Also, update SemiJoinRule to accept RelBuilderFactory and class parameters.


Project: http://git-wip-us.apache.org/repos/asf/calcite/repo
Commit: http://git-wip-us.apache.org/repos/asf/calcite/commit/565d6392
Tree: http://git-wip-us.apache.org/repos/asf/calcite/tree/565d6392
Diff: http://git-wip-us.apache.org/repos/asf/calcite/diff/565d6392

Branch: refs/heads/master
Commit: 565d639261d126462a78e57ff7f21527f38697a9
Parents: bac9ee7
Author: Julian Hyde <jh...@apache.org>
Authored: Thu Dec 1 16:02:13 2016 -0800
Committer: Julian Hyde <jh...@apache.org>
Committed: Thu Dec 1 20:10:18 2016 -0800

----------------------------------------------------------------------
 .../apache/calcite/rel/rules/SemiJoinRule.java  |  79 +++++++++----
 .../apache/calcite/test/RelOptRulesTest.java    |  95 ++++++++++++++-
 .../org/apache/calcite/test/RelOptRulesTest.xml | 115 ++++++++++++++++---
 core/src/test/resources/sql/misc.iq             |  29 +++--
 4 files changed, 271 insertions(+), 47 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/calcite/blob/565d6392/core/src/main/java/org/apache/calcite/rel/rules/SemiJoinRule.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/SemiJoinRule.java b/core/src/main/java/org/apache/calcite/rel/rules/SemiJoinRule.java
index db87da3..13ad991 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/SemiJoinRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/SemiJoinRule.java
@@ -25,12 +25,15 @@ import org.apache.calcite.rel.core.Aggregate;
 import org.apache.calcite.rel.core.Join;
 import org.apache.calcite.rel.core.JoinInfo;
 import org.apache.calcite.rel.core.Project;
-import org.apache.calcite.rel.core.SemiJoin;
+import org.apache.calcite.rel.core.RelFactories;
 import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.tools.RelBuilder;
+import org.apache.calcite.tools.RelBuilderFactory;
 import org.apache.calcite.util.ImmutableBitSet;
 import org.apache.calcite.util.ImmutableIntList;
 
+import com.google.common.base.Predicate;
 import com.google.common.collect.Lists;
 
 import java.util.List;
@@ -41,15 +44,34 @@ import java.util.List;
  * {@link org.apache.calcite.rel.logical.LogicalAggregate}.
  */
 public class SemiJoinRule extends RelOptRule {
-  public static final SemiJoinRule INSTANCE = new SemiJoinRule();
+  private static final Predicate<Join> IS_LEFT_OR_INNER =
+      new Predicate<Join>() {
+        public boolean apply(Join input) {
+          switch (input.getJoinType()) {
+          case LEFT:
+          case INNER:
+            return true;
+          default:
+            return false;
+          }
+        }
+      };
 
-  private SemiJoinRule() {
+  public static final SemiJoinRule INSTANCE =
+      new SemiJoinRule(Project.class, Join.class, Aggregate.class,
+          RelFactories.LOGICAL_BUILDER, "SemiJoinRule");
+
+  /** Creates a SemiJoinRule. */
+  public SemiJoinRule(Class<Project> projectClass, Class<Join> joinClass,
+      Class<Aggregate> aggregateClass, RelBuilderFactory relBuilderFactory,
+      String description) {
     super(
-        operand(Project.class,
+        operand(projectClass,
             some(
-                operand(Join.class,
+                operand(joinClass, null, IS_LEFT_OR_INNER,
                     some(operand(RelNode.class, any()),
-                        operand(Aggregate.class, any()))))));
+                        operand(aggregateClass, any()))))),
+        relBuilderFactory, description);
   }
 
   @Override public void onMatch(RelOptRuleCall call) {
@@ -77,24 +99,35 @@ public class SemiJoinRule extends RelOptRule {
     if (!joinInfo.isEqui()) {
       return;
     }
-    final List<Integer> newRightKeyBuilder = Lists.newArrayList();
-    final List<Integer> aggregateKeys = aggregate.getGroupSet().asList();
-    for (int key : joinInfo.rightKeys) {
-      newRightKeyBuilder.add(aggregateKeys.get(key));
+    final RelBuilder relBuilder = call.builder();
+    relBuilder.push(left);
+    switch (join.getJoinType()) {
+    case INNER:
+      final List<Integer> newRightKeyBuilder = Lists.newArrayList();
+      final List<Integer> aggregateKeys = aggregate.getGroupSet().asList();
+      for (int key : joinInfo.rightKeys) {
+        newRightKeyBuilder.add(aggregateKeys.get(key));
+      }
+      final ImmutableIntList newRightKeys = ImmutableIntList.copyOf(newRightKeyBuilder);
+      relBuilder.push(aggregate.getInput());
+      final RexNode newCondition =
+          RelOptUtil.createEquiJoinCondition(relBuilder.peek(2, 0),
+              joinInfo.leftKeys, relBuilder.peek(2, 1), newRightKeys,
+              rexBuilder);
+      relBuilder.semiJoin(newCondition);
+      break;
+
+    case LEFT:
+      // The right-hand side produces no more than 1 row (because of the
+      // Aggregate) and no fewer than 1 row (because of LEFT), and therefore
+      // we can eliminate the semi-join.
+      break;
+
+    default:
+      throw new AssertionError(join.getJoinType());
     }
-    final ImmutableIntList newRightKeys =
-        ImmutableIntList.copyOf(newRightKeyBuilder);
-    final RelNode newRight = aggregate.getInput();
-    final RexNode newCondition =
-        RelOptUtil.createEquiJoinCondition(left, joinInfo.leftKeys, newRight,
-            newRightKeys, rexBuilder);
-    final SemiJoin semiJoin =
-        SemiJoin.create(left, newRight, newCondition, joinInfo.leftKeys,
-            newRightKeys);
-    final Project newProject =
-        project.copy(project.getTraitSet(), semiJoin, project.getProjects(),
-            project.getRowType());
-    call.transformTo(ProjectRemoveRule.strip(newProject));
+    relBuilder.project(project.getProjects(), project.getRowType().getFieldNames());
+    call.transformTo(relBuilder.build());
   }
 }
 

http://git-wip-us.apache.org/repos/asf/calcite/blob/565d6392/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index 3097d47..ebc0e99 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -493,7 +493,7 @@ public class RelOptRulesTest extends RelOptTestBase {
     checkPlanning(program, sql);
   }
 
-  @Test public void testSemiJoinRule() {
+  @Test public void testSemiJoinRuleExists() {
     final HepProgram preProgram =
         HepProgram.builder()
             .addRuleInstance(FilterProjectTransposeRule.INSTANCE)
@@ -516,6 +516,99 @@ public class RelOptRulesTest extends RelOptTestBase {
         .check();
   }
 
+  @Test public void testSemiJoinRule() {
+    final HepProgram preProgram =
+        HepProgram.builder()
+            .addRuleInstance(FilterProjectTransposeRule.INSTANCE)
+            .addRuleInstance(FilterJoinRule.FILTER_ON_JOIN)
+            .addRuleInstance(ProjectMergeRule.INSTANCE)
+            .build();
+    final HepProgram program =
+        HepProgram.builder()
+            .addRuleInstance(SemiJoinRule.INSTANCE)
+            .build();
+    final String sql = "select dept.* from dept join (\n"
+        + "  select distinct deptno from emp\n"
+        + "  where sal > 100) using (deptno)";
+    sql(sql)
+        .withDecorrelation(true)
+        .withTrim(true)
+        .withPre(preProgram)
+        .with(program)
+        .check();
+  }
+
+  /** Test case for
+   * <a href="https://issues.apache.org/jira/browse/CALCITE-1495">[CALCITE-1495]
+   * SemiJoinRule should not apply to RIGHT and FULL JOIN</a>. */
+  @Test public void testSemiJoinRuleRight() {
+    final HepProgram preProgram =
+        HepProgram.builder()
+            .addRuleInstance(FilterProjectTransposeRule.INSTANCE)
+            .addRuleInstance(FilterJoinRule.FILTER_ON_JOIN)
+            .addRuleInstance(ProjectMergeRule.INSTANCE)
+            .build();
+    final HepProgram program =
+        HepProgram.builder()
+            .addRuleInstance(SemiJoinRule.INSTANCE)
+            .build();
+    final String sql = "select dept.* from dept right join (\n"
+        + "  select distinct deptno from emp\n"
+        + "  where sal > 100) using (deptno)";
+    sql(sql)
+        .withPre(preProgram)
+        .with(program)
+        .withDecorrelation(true)
+        .withTrim(true)
+        .checkUnchanged();
+  }
+
+  /** Similar to {@link #testSemiJoinRuleRight()} but FULL. */
+  @Test public void testSemiJoinRuleFull() {
+    final HepProgram preProgram =
+        HepProgram.builder()
+            .addRuleInstance(FilterProjectTransposeRule.INSTANCE)
+            .addRuleInstance(FilterJoinRule.FILTER_ON_JOIN)
+            .addRuleInstance(ProjectMergeRule.INSTANCE)
+            .build();
+    final HepProgram program =
+        HepProgram.builder()
+            .addRuleInstance(SemiJoinRule.INSTANCE)
+            .build();
+    final String sql = "select dept.* from dept full join (\n"
+        + "  select distinct deptno from emp\n"
+        + "  where sal > 100) using (deptno)";
+    sql(sql)
+        .withPre(preProgram)
+        .with(program)
+        .withDecorrelation(true)
+        .withTrim(true)
+        .checkUnchanged();
+  }
+
+  /** Similar to {@link #testSemiJoinRule()} but LEFT. */
+  @Test public void testSemiJoinRuleLeft() {
+    final HepProgram preProgram =
+        HepProgram.builder()
+            .addRuleInstance(FilterProjectTransposeRule.INSTANCE)
+            .addRuleInstance(FilterJoinRule.FILTER_ON_JOIN)
+            .addRuleInstance(ProjectMergeRule.INSTANCE)
+            .build();
+    final HepProgram program =
+        HepProgram.builder()
+            .addRuleInstance(SemiJoinRule.INSTANCE)
+            .build();
+    final String sql = "select name from dept left join (\n"
+        + "  select distinct deptno from emp\n"
+        + "  where sal > 100) using (deptno)";
+    sql(sql)
+        .withPre(preProgram)
+        .with(program)
+        .withDecorrelation(true)
+        .withTrim(true)
+        .check();
+  }
+
   /** Test case for
    * <a href="https://issues.apache.org/jira/browse/CALCITE-438">[CALCITE-438]
    * Push predicates through SemiJoin</a>. */

http://git-wip-us.apache.org/repos/asf/calcite/blob/565d6392/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
----------------------------------------------------------------------
diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index ba017bd..5313698 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -3842,7 +3842,7 @@ LogicalProject(DEPTNO=[$0], NAME=[$1], EMPNO=[$2], ENAME=[$3], JOB=[$4], MGR=[$5
 ]]>
         </Resource>
     </TestCase>
-    <TestCase name="testSemiJoinRule">
+    <TestCase name="testSemiJoinRuleExists">
         <Resource name="sql">
             <![CDATA[select * from dept where exists (
   select * from emp
@@ -5130,6 +5130,93 @@ LogicalProject(DEPTNO=[$3], SUM_SAL=[$1], C=[$2])
 ]]>
         </Resource>
     </TestCase>
+    <TestCase name="testSemiJoinRule">
+        <Resource name="sql">
+            <![CDATA[select dept.* from dept join (
+  select distinct deptno from emp
+  where sal > 100) using (deptno)]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalProject(DEPTNO=[$0], NAME=[$1])
+  LogicalJoin(condition=[=($0, $2)], joinType=[inner])
+    LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+    LogicalAggregate(group=[{0}])
+      LogicalProject(DEPTNO=[$7])
+        LogicalFilter(condition=[>($5, 100)])
+          LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+        <Resource name="planAfter">
+            <![CDATA[
+SemiJoin(condition=[=($0, $2)], joinType=[inner])
+  LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+  LogicalProject(DEPTNO=[$7])
+    LogicalFilter(condition=[>($5, 100)])
+      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+    </TestCase>
+    <TestCase name="testSemiJoinRuleFull">
+        <Resource name="sql">
+            <![CDATA[select dept.* from dept full join (
+  select distinct deptno from emp
+  where sal > 100) using (deptno)]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalProject(DEPTNO=[$0], NAME=[$1])
+  LogicalJoin(condition=[=($0, $2)], joinType=[full])
+    LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+    LogicalAggregate(group=[{0}])
+      LogicalProject(DEPTNO=[$7])
+        LogicalFilter(condition=[>($5, 100)])
+          LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+    </TestCase>
+    <TestCase name="testSemiJoinRuleLeft">
+        <Resource name="sql">
+            <![CDATA[select name from dept left join (
+  select distinct deptno from emp
+  where sal > 100) using (deptno)]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalProject(NAME=[$1])
+  LogicalJoin(condition=[=($0, $2)], joinType=[left])
+    LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+    LogicalAggregate(group=[{0}])
+      LogicalProject(DEPTNO=[$7])
+        LogicalFilter(condition=[>($5, 100)])
+          LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+        <Resource name="planAfter">
+            <![CDATA[
+LogicalProject(NAME=[$1])
+  LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+        </Resource>
+    </TestCase>
+    <TestCase name="testSemiJoinRuleRight">
+        <Resource name="sql">
+            <![CDATA[select dept.* from dept right join (
+  select distinct deptno from emp
+  where sal > 100) using (deptno)]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalProject(DEPTNO=[$0], NAME=[$1])
+  LogicalJoin(condition=[=($0, $2)], joinType=[right])
+    LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+    LogicalAggregate(group=[{0}])
+      LogicalProject(DEPTNO=[$7])
+        LogicalFilter(condition=[>($5, 100)])
+          LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+    </TestCase>
     <TestCase name="testSortJoinTranspose1">
         <Resource name="sql">
             <![CDATA[select * from sales.emp e left join (
@@ -6192,7 +6279,10 @@ LogicalProject(JOB=[$0], EMPNO=[10], SAL=[$1], S=[$2])
     </TestCase>
     <TestCase name="testWhereNotInCorrelated">
         <Resource name="sql">
-            <![CDATA[select sal from emp where empno NOT IN (select deptno from dept where emp.job = dept.name)]]>
+            <![CDATA[select sal from emp
+where empno NOT IN (
+  select deptno from dept
+  where emp.job = dept.name)]]>
         </Resource>
         <Resource name="planBefore">
             <![CDATA[
@@ -6209,20 +6299,15 @@ LogicalProject(DEPTNO=[$0])
             <![CDATA[
 LogicalProject(SAL=[$5])
   LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
-    LogicalFilter(condition=[IS NULL($11)])
-      LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], DEPTNO0=[CAST($9):INTEGER], JOB0=[CAST($10):VARCHAR(10) CHARACTER SET "ISO-8859-1" COLLATE "ISO-8859-1$en_US$primary"], $f2=[CAST($11):BOOLEAN])
-        LogicalJoin(condition=[AND(=($2, $10), =($0, $9))], joinType=[inner])
+    LogicalFilter(condition=[IS NULL($10)])
+      LogicalFilter(condition=[=($0, $9)])
+        LogicalCorrelate(correlation=[$cor0], joinType=[LEFT], requiredColumns=[{2}])
           LogicalTableScan(table=[[CATALOG, SALES, EMP]])
-          LogicalProject(DEPTNO=[$0], JOB=[$1], $f2=[true])
-            LogicalAggregate(group=[{0, 1}])
-              LogicalProject(DEPTNO=[$0], JOB=[$2], i=[$1])
-                LogicalProject(DEPTNO=[$0], i=[true], JOB=[$1])
-                  LogicalProject(DEPTNO=[$0], JOB=[$2])
-                    LogicalJoin(condition=[=($2, $1)], joinType=[inner])
-                      LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
-                      LogicalAggregate(group=[{0}])
-                        LogicalProject(JOB=[$2])
-                          LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+          LogicalAggregate(group=[{0, 1}])
+            LogicalProject(DEPTNO=[$0], i=[true])
+              LogicalProject(DEPTNO=[$0])
+                LogicalFilter(condition=[=($cor0.JOB, $1)])
+                  LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
 ]]>
         </Resource>
     </TestCase>

http://git-wip-us.apache.org/repos/asf/calcite/blob/565d6392/core/src/test/resources/sql/misc.iq
----------------------------------------------------------------------
diff --git a/core/src/test/resources/sql/misc.iq b/core/src/test/resources/sql/misc.iq
index 0378e26..972b73c 100644
--- a/core/src/test/resources/sql/misc.iq
+++ b/core/src/test/resources/sql/misc.iq
@@ -424,18 +424,31 @@ EnumerableCalc(expr#0..7=[{inputs}], expr#8=[IS NULL($t5)], expr#9=[IS NULL($t7)
         EnumerableJoin(condition=[=($1, $3)], joinType=[inner])
           EnumerableCalc(expr#0=[{inputs}], expr#1=[CAST($t0):INTEGER NOT NULL], proj#0..1=[{exprs}])
             EnumerableAggregate(group=[{0}])
-              EnumerableSemiJoin(condition=[=($1, $2)], joinType=[inner])
-                EnumerableCalc(expr#0..4=[{inputs}], proj#0..1=[{exprs}])
-                  EnumerableTableScan(table=[[hr, emps]])
-                EnumerableJoin(condition=[=($0, $1)], joinType=[inner])
-                  EnumerableAggregate(group=[{1}])
-                    EnumerableTableScan(table=[[hr, emps]])
-                  EnumerableCalc(expr#0..3=[{inputs}], deptno=[$t0])
-                    EnumerableTableScan(table=[[hr, depts]])
+              EnumerableTableScan(table=[[hr, emps]])
           EnumerableCalc(expr#0..3=[{inputs}], expr#4=[90], expr#5=[+($t0, $t4)], deptno=[$t0], $f1=[$t5])
             EnumerableTableScan(table=[[hr, depts]])
 !plan
 
+# Left join to a relation with one row is recognized as a trivial semi-join
+# and eliminated.
+select e."deptno"
+from "hr"."emps" as e
+left join (select count(*) from "hr"."depts") on true;
++--------+
+| deptno |
++--------+
+|     10 |
+|     10 |
+|     10 |
+|     20 |
++--------+
+(4 rows)
+
+!ok
+EnumerableCalc(expr#0..4=[{inputs}], deptno=[$t1])
+  EnumerableTableScan(table=[[hr, emps]])
+!plan
+
 # Filter combined with an OR filter.
 select * from (
   select * from "hr"."emps" as e

Reply via email to