This is an automated email from the ASF dual-hosted git repository.

zabetak pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/master by this push:
     new 8c2228e  [CALCITE-4574] Wrong/Invalid plans when using RelBuilder#join with correlations
8c2228e is described below

commit 8c2228eaf8ccc05ae58778276e760092557f78cc
Author: Stamatis Zampetakis <[email protected]>
AuthorDate: Fri Apr 9 19:04:20 2021 +0200

    [CALCITE-4574] Wrong/Invalid plans when using RelBuilder#join with correlations
    
    1. Gather the required columns from the right side after the filter has
    been handled, so that columns referenced in the join condition are
    included.
    2. Push the predicate for SEMI/ANTI join types to the right input,
    because otherwise the columns in the condition that reference the right
    side would be invalid.
    3. Throw IllegalArgumentException for unsupported correlated join types
    (e.g. RIGHT and FULL).
    4. Update existing tests with the correct plans
    5. Add new tests for RelBuilder#join with correlation covering all join
    types.
    
    Close apache/calcite#2393
---
 .../java/org/apache/calcite/tools/RelBuilder.java  | 16 +++--
 .../rel/logical/ToLogicalConverterTest.java        |  2 +-
 .../org/apache/calcite/test/RelBuilderTest.java    | 68 +++++++++++++++++++++-
 .../org/apache/calcite/test/RelOptRulesTest.xml    |  6 +-
 4 files changed, 80 insertions(+), 12 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java 
b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
index ba1d362..269a72f 100644
--- a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
+++ b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
@@ -2375,24 +2375,28 @@ public class RelBuilder {
     }
     if (correlate) {
       final CorrelationId id = Iterables.getOnlyElement(variablesSet);
-      final ImmutableBitSet requiredColumns =
-          RelOptUtil.correlationColumns(id, right.rel);
       if (!RelOptUtil.notContainsCorrelation(left.rel, id, Litmus.IGNORE)) {
         throw new IllegalArgumentException("variable " + id
             + " must not be used by left input to correlation");
       }
+      // Correlate does not have an ON clause.
       switch (joinType) {
       case LEFT:
-        // Correlate does not have an ON clause.
-        // For a LEFT correlate, predicate must be evaluated first.
-        // For INNER, we can defer.
+      case SEMI:
+      case ANTI:
+        // For a LEFT/SEMI/ANTI, predicate must be evaluated first.
         stack.push(right);
         filter(condition.accept(new Shifter(left.rel, id, right.rel)));
         right = stack.pop();
         break;
-      default:
+      case INNER:
+        // For INNER, we can defer.
         postCondition = condition;
+        break;
+      default:
+        throw new IllegalArgumentException("Correlated " + joinType + " join 
is not supported");
       }
+      final ImmutableBitSet requiredColumns = 
RelOptUtil.correlationColumns(id, right.rel);
       join =
           struct.correlateFactory.createCorrelate(left.rel, right.rel, id,
               requiredColumns, joinType);
diff --git 
a/core/src/test/java/org/apache/calcite/rel/logical/ToLogicalConverterTest.java 
b/core/src/test/java/org/apache/calcite/rel/logical/ToLogicalConverterTest.java
index 63e2636..aec9d95 100644
--- 
a/core/src/test/java/org/apache/calcite/rel/logical/ToLogicalConverterTest.java
+++ 
b/core/src/test/java/org/apache/calcite/rel/logical/ToLogicalConverterTest.java
@@ -333,7 +333,7 @@ class ToLogicalConverterTest {
             ImmutableSet.of(v.get().id))
         .build();
     String expectedPhysical = ""
-        + "EnumerableCorrelate(correlation=[$cor0], joinType=[left], 
requiredColumns=[{}])\n"
+        + "EnumerableCorrelate(correlation=[$cor0], joinType=[left], 
requiredColumns=[{5}])\n"
         + "  EnumerableTableScan(table=[[scott, EMP]])\n"
         + "  EnumerableFilter(condition=[=($cor0.SAL, 1000)])\n"
         + "    EnumerableTableScan(table=[[scott, DEPT]])\n";
diff --git a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java 
b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
index 1d062f4..2586317 100644
--- a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
@@ -2336,7 +2336,7 @@ public class RelBuilderTest {
     // Note that the join filter gets pushed to the right-hand input of
     // LogicalCorrelate
     final String expected = ""
-        + "LogicalCorrelate(correlation=[$cor0], joinType=[left], 
requiredColumns=[{7}])\n"
+        + "LogicalCorrelate(correlation=[$cor0], joinType=[left], 
requiredColumns=[{5, 7}])\n"
         + "  LogicalTableScan(table=[[scott, EMP]])\n"
         + "  LogicalFilter(condition=[=($cor0.SAL, 1000)])\n"
         + "    LogicalFilter(condition=[=($0, $cor0.DEPTNO)])\n"
@@ -3670,7 +3670,7 @@ public class RelBuilderTest {
 
     final String expected = ""
         + "LogicalCorrelate(correlation=[$cor0], joinType=[left], "
-        + "requiredColumns=[{7}])\n"
+        + "requiredColumns=[{5, 7}])\n"
         + "  LogicalTableScan(table=[[scott, EMP]])\n"
         + "  LogicalFilter(condition=[=($cor0.SAL, 1000)])\n"
         + "    LogicalFilter(condition=[OR("
@@ -3870,6 +3870,70 @@ public class RelBuilderTest {
     assertThat(root, hasTree(expected));
   }
 
+  @Test void testSimpleSemiCorrelateViaJoin() {
+    RelNode root = buildSimpleCorrelateWithJoin(JoinRelType.SEMI);
+    final String expected = ""
+        + "LogicalCorrelate(correlation=[$cor0], joinType=[semi], 
requiredColumns=[{7}])\n"
+        + "  LogicalTableScan(table=[[scott, EMP]])\n"
+        + "  LogicalFilter(condition=[=($cor0.DEPTNO, $0)])\n"
+        + "    LogicalTableScan(table=[[scott, DEPT]])\n";
+    assertThat(root, hasTree(expected));
+  }
+
+  @Test void testSimpleAntiCorrelateViaJoin() {
+    RelNode root = buildSimpleCorrelateWithJoin(JoinRelType.ANTI);
+    final String expected = ""
+        + "LogicalCorrelate(correlation=[$cor0], joinType=[anti], 
requiredColumns=[{7}])\n"
+        + "  LogicalTableScan(table=[[scott, EMP]])\n"
+        + "  LogicalFilter(condition=[=($cor0.DEPTNO, $0)])\n"
+        + "    LogicalTableScan(table=[[scott, DEPT]])\n";
+    assertThat(root, hasTree(expected));
+  }
+
+  @Test void testSimpleLeftCorrelateViaJoin() {
+    RelNode root = buildSimpleCorrelateWithJoin(JoinRelType.LEFT);
+    final String expected = ""
+        + "LogicalCorrelate(correlation=[$cor0], joinType=[left], 
requiredColumns=[{7}])\n"
+        + "  LogicalTableScan(table=[[scott, EMP]])\n"
+        + "  LogicalFilter(condition=[=($cor0.DEPTNO, $0)])\n"
+        + "    LogicalTableScan(table=[[scott, DEPT]])\n";
+    assertThat(root, hasTree(expected));
+  }
+
+  @Test void testSimpleInnerCorrelateViaJoin() {
+    RelNode root = buildSimpleCorrelateWithJoin(JoinRelType.INNER);
+    final String expected = ""
+        + "LogicalFilter(condition=[=($7, $8)])\n"
+        + "  LogicalCorrelate(correlation=[$cor0], joinType=[inner], 
requiredColumns=[{}])\n"
+        + "    LogicalTableScan(table=[[scott, EMP]])\n"
+        + "    LogicalTableScan(table=[[scott, DEPT]])\n";
+    assertThat(root, hasTree(expected));
+  }
+
+  @Test void testSimpleRightCorrelateViaJoinThrowsException() {
+    assertThrows(IllegalArgumentException.class,
+        () -> buildSimpleCorrelateWithJoin(JoinRelType.RIGHT));
+  }
+
+  @Test void testSimpleFullCorrelateViaJoinThrowsException() {
+    assertThrows(IllegalArgumentException.class,
+        () -> buildSimpleCorrelateWithJoin(JoinRelType.FULL));
+  }
+
+  private static RelNode buildSimpleCorrelateWithJoin(JoinRelType type) {
+    final RelBuilder builder = RelBuilder.create(config().build());
+    final Holder<@Nullable RexCorrelVariable> v = Holder.empty();
+    return builder
+        .scan("EMP")
+        .variable(v)
+        .scan("DEPT")
+        .join(type,
+            builder.equals(
+                builder.field(2, 0, "DEPTNO"),
+                builder.field(2, 1, "DEPTNO")), ImmutableSet.of(v.get().id))
+        .build();
+  }
+
   @Test void testCorrelateWithComplexFields() {
     final RelBuilder builder = RelBuilder.create(config().build());
     final Holder<@Nullable RexCorrelVariable> v = Holder.empty();
diff --git 
a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml 
b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index c30446f..4c509c3 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -10545,7 +10545,7 @@ LogicalProject(DEPTNO=[$0])
         <Resource name="planMid">
             <![CDATA[
 LogicalProject(SAL=[$5], EXPR$1=[IS NULL($10)])
-  LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{2}])
+  LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0, 
2}])
     LogicalTableScan(table=[[CATALOG, SALES, EMP]])
     LogicalFilter(condition=[=($cor0.EMPNO, $0)])
       LogicalProject(DEPTNO=[$0], i=[true])
@@ -12307,7 +12307,7 @@ LogicalProject(SAL=[$5])
 LogicalProject(SAL=[$5])
   LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], 
SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
     LogicalFilter(condition=[OR(=($9, 0), IS NOT TRUE(OR(IS NOT NULL($12), 
<($10, $9))))])
-      LogicalCorrelate(correlation=[$cor0], joinType=[left], 
requiredColumns=[{2}])
+      LogicalCorrelate(correlation=[$cor0], joinType=[left], 
requiredColumns=[{0, 2}])
         LogicalCorrelate(correlation=[$cor0], joinType=[left], 
requiredColumns=[{2}])
           LogicalTableScan(table=[[CATALOG, SALES, EMP]])
           LogicalProject(c=[$0], ck=[$0])
@@ -12384,7 +12384,7 @@ LogicalProject(EMPNO=[$1])
 LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], 
SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
   LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], 
SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
     LogicalFilter(condition=[OR(=($9, 0), IS NOT TRUE(OR(IS NOT NULL($12), 
<($10, $9))))])
-      LogicalCorrelate(correlation=[$cor0], joinType=[left], 
requiredColumns=[{1}])
+      LogicalCorrelate(correlation=[$cor0], joinType=[left], 
requiredColumns=[{0, 1}])
         LogicalCorrelate(correlation=[$cor0], joinType=[left], 
requiredColumns=[{1}])
           LogicalTableScan(table=[[CATALOG, SALES, EMP]])
           LogicalProject(c=[$0], ck=[$0])

Reply via email to