This is an automated email from the ASF dual-hosted git repository.

rubenql pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new 3aee0b86aa [CALCITE-5732] EnumerableHashJoin and EnumerableMergeJoin on composite key return rows matching condition 'null = null'
3aee0b86aa is described below

commit 3aee0b86aa23476cbdecc75ad5d43b936a6fff7b
Author: rubenada <rube...@gmail.com>
AuthorDate: Wed Jul 12 14:45:05 2023 +0100

    [CALCITE-5732] EnumerableHashJoin and EnumerableMergeJoin on composite key return rows matching condition 'null = null'
---
 .../adapter/enumerable/EnumerableHashJoin.java     |   8 +-
 .../adapter/enumerable/EnumerableMergeJoin.java    |   7 +-
 .../calcite/adapter/enumerable/PhysType.java       |  26 +++-
 .../calcite/adapter/enumerable/PhysTypeImpl.java   | 168 +++++++++++++--------
 .../java/org/apache/calcite/runtime/Utilities.java |  11 ++
 .../org/apache/calcite/util/BuiltInMethod.java     |   2 +-
 .../apache/calcite/runtime/EnumerablesTest.java    |  28 ++--
 .../test/enumerable/EnumerableHashJoinTest.java    | 136 ++++++++++++++---
 .../test/enumerable/EnumerableJoinTest.java        | 101 ++++++++++---
 .../apache/calcite/linq4j/EnumerableDefaults.java  | 105 ++++++++-----
 10 files changed, 430 insertions(+), 162 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableHashJoin.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableHashJoin.java
index 259fb68503..cfe44da75d 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableHashJoin.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableHashJoin.java
@@ -218,8 +218,8 @@ public class EnumerableHashJoin extends Join implements EnumerableRel {
                 Expressions.list(
                     leftExpression,
                     rightExpression,
-                    leftResult.physType.generateAccessor(joinInfo.leftKeys),
-                    rightResult.physType.generateAccessor(joinInfo.rightKeys),
+                    
leftResult.physType.generateAccessorWithoutNulls(joinInfo.leftKeys),
+                    
rightResult.physType.generateAccessorWithoutNulls(joinInfo.rightKeys),
                     Util.first(keyPhysType.comparer(),
                         Expressions.constant(null)),
                     predicate)))
@@ -264,8 +264,8 @@ public class EnumerableHashJoin extends Join implements EnumerableRel {
                 BuiltInMethod.HASH_JOIN.method,
                 Expressions.list(
                     rightExpression,
-                    leftResult.physType.generateAccessor(joinInfo.leftKeys),
-                    rightResult.physType.generateAccessor(joinInfo.rightKeys),
+                    
leftResult.physType.generateAccessorWithoutNulls(joinInfo.leftKeys),
+                    
rightResult.physType.generateAccessorWithoutNulls(joinInfo.rightKeys),
                     EnumUtils.joinSelector(joinType,
                         physType,
                         ImmutableList.of(
diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeJoin.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeJoin.java
index d3f45d8d31..3f42e065c8 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeJoin.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeJoin.java
@@ -492,7 +492,7 @@ public class EnumerableMergeJoin extends Join implements EnumerableRel {
               RelFieldCollation.NullDirection.LAST));
     }
     final RelCollation collation = RelCollations.of(fieldCollations);
-    final Expression comparator = 
leftKeyPhysType.generateComparator(collation);
+    final Expression comparator = 
leftKeyPhysType.generateMergeJoinComparator(collation);
 
     return implementor.result(
         physType,
@@ -512,6 +512,9 @@ public class EnumerableMergeJoin extends Join implements EnumerableRel {
                         ImmutableList.of(
                             leftResult.physType, rightResult.physType)),
                     Expressions.constant(EnumUtils.toLinq4jJoinType(joinType)),
-                    comparator))).toBlock());
+                    comparator,
+                    Util.first(
+                        leftKeyPhysType.comparer(),
+                        Expressions.constant(null))))).toBlock());
   }
 }
diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/PhysType.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/PhysType.java
index 4447e91dfe..8d4eeb176c 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/PhysType.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/PhysType.java
@@ -112,11 +112,28 @@ public interface PhysType {
    *    public Object[] apply(Employee v1) {
    *        return FlatLists.of(v1.&lt;fieldN&gt;, v1.&lt;fieldM&gt;);
    *    }
-   * }
    * }</pre></blockquote>
    */
   Expression generateAccessor(List<Integer> fields);
 
+  /** Similar to {@link #generateAccessor(List)}, but if one of the fields is 
<code>null</code>,
+   * it will return <code>null</code>.
+   *
+   * <p>For example:
+   *
+   * <blockquote><pre>
+   * new Function1&lt;Employee, Object[]&gt; {
+   *    public Object[] apply(Employee v1) {
+   *        return v1.&lt;fieldN&gt; == null
+   *            ? null
+   *            : v1.&lt;fieldM&gt; == null
+   *                ? null
+   *                : FlatLists.of(v1.&lt;fieldN&gt;, v1.&lt;fieldM&gt;);
+   *    }
+   * }</pre></blockquote>
+   */
+  Expression generateAccessorWithoutNulls(List<Integer> fields);
+
   /** Generates a selector for the given fields from an expression, with the
    * default row format. */
   Expression generateSelector(
@@ -181,6 +198,13 @@ public interface PhysType {
   Expression generateComparator(
       RelCollation collation);
 
+  /** Similar to {@link #generateComparator(RelCollation)}, but with some 
specificities for
+   * MergeJoin algorithm: it will not consider two <code>null</code> values as 
equal.
+   *
+   * @see 
org.apache.calcite.linq4j.EnumerableDefaults#compareNullsLastForMergeJoin
+   */
+  Expression generateMergeJoinComparator(RelCollation collation);
+
   /** Returns a expression that yields a comparer, or null if this type
    * is comparable. */
   @Nullable Expression comparer();
diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/PhysTypeImpl.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/PhysTypeImpl.java
index 55a44b2a29..ab51dd3e35 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/PhysTypeImpl.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/PhysTypeImpl.java
@@ -407,6 +407,24 @@ public class PhysTypeImpl implements PhysType {
   }
 
   @Override public Expression generateComparator(RelCollation collation) {
+    return this.generateComparator(collation, fieldCollation -> {
+      final int index = fieldCollation.getFieldIndex();
+      final boolean nullsFirst =
+          fieldCollation.nullDirection
+              == RelFieldCollation.NullDirection.FIRST;
+      final boolean descending =
+          fieldCollation.getDirection()
+              == RelFieldCollation.Direction.DESCENDING;
+      return fieldNullable(index)
+          ? (nullsFirst != descending
+          ? "compareNullsFirst"
+          : "compareNullsLast")
+          : "compare";
+    });
+  }
+
+  private Expression generateComparator(RelCollation collation,
+      Function1<RelFieldCollation, String> compareMethodNameFunction) {
     // int c;
     // c = Utilities.compare(v0, v1);
     // if (c != 0) return c; // or -c if descending
@@ -437,9 +455,6 @@ public class PhysTypeImpl implements PhysType {
       default:
         break;
       }
-      final boolean nullsFirst =
-          fieldCollation.nullDirection
-              == RelFieldCollation.NullDirection.FIRST;
       final boolean descending =
           fieldCollation.getDirection()
               == RelFieldCollation.Direction.DESCENDING;
@@ -449,11 +464,7 @@ public class PhysTypeImpl implements PhysType {
                   parameterC,
                   Expressions.call(
                       Utilities.class,
-                      fieldNullable(index)
-                          ? (nullsFirst != descending
-                          ? "compareNullsFirst"
-                          : "compareNullsLast")
-                          : "compare",
+                      compareMethodNameFunction.apply(fieldCollation),
                       Expressions.list(
                           arg0,
                           arg1)
@@ -511,6 +522,17 @@ public class PhysTypeImpl implements PhysType {
         memberDeclarations);
   }
 
+  @Override public Expression generateMergeJoinComparator(RelCollation 
collation) {
+    return this.generateComparator(collation, fieldCollation -> {
+      // merge join keys must be sorted in ascending order, nulls last
+      assert fieldCollation.nullDirection == 
RelFieldCollation.NullDirection.LAST;
+      assert fieldCollation.getDirection() == 
RelFieldCollation.Direction.ASCENDING;
+      return fieldNullable(fieldCollation.getFieldIndex())
+          ? "compareNullsLastForMergeJoin"
+          : "compare";
+    });
+  }
+
   @Override public RelDataType getRowType() {
     return rowType;
   }
@@ -616,65 +638,79 @@ public class PhysTypeImpl implements PhysType {
       for (int field : fields) {
         list.add(fieldReference(v1, field));
       }
-      switch (list.size()) {
-      case 2:
-        return Expressions.lambda(
-            Function1.class,
-            Expressions.call(
-                List.class,
-                null,
-                BuiltInMethod.LIST2.method,
-                list),
-            v1);
-      case 3:
-        return Expressions.lambda(
-            Function1.class,
-            Expressions.call(
-                List.class,
-                null,
-                BuiltInMethod.LIST3.method,
-                list),
-            v1);
-      case 4:
-        return Expressions.lambda(
-            Function1.class,
-            Expressions.call(
-                List.class,
-                null,
-                BuiltInMethod.LIST4.method,
-                list),
-            v1);
-      case 5:
-        return Expressions.lambda(
-            Function1.class,
-            Expressions.call(
-                List.class,
-                null,
-                BuiltInMethod.LIST5.method,
-                list),
-            v1);
-      case 6:
-        return Expressions.lambda(
-            Function1.class,
-            Expressions.call(
-                List.class,
-                null,
-                BuiltInMethod.LIST6.method,
-                list),
-            v1);
-      default:
-        return Expressions.lambda(
-            Function1.class,
-            Expressions.call(
-                List.class,
-                null,
-                BuiltInMethod.LIST_N.method,
-                Expressions.newArrayInit(
-                    Comparable.class,
-                    list)),
-            v1);
-      }
+      return Expressions.lambda(Function1.class, getListExpression(list), v1);
+    }
+  }
+
+  private static Expression 
getListExpression(Expressions.FluentList<Expression> list) {
+    assert list.size() >= 2;
+
+    switch (list.size()) {
+    case 2:
+      return Expressions.call(
+              List.class,
+              null,
+              BuiltInMethod.LIST2.method,
+              list);
+    case 3:
+      return Expressions.call(
+              List.class,
+              null,
+              BuiltInMethod.LIST3.method,
+              list);
+    case 4:
+      return Expressions.call(
+              List.class,
+              null,
+              BuiltInMethod.LIST4.method,
+              list);
+    case 5:
+      return Expressions.call(
+              List.class,
+              null,
+              BuiltInMethod.LIST5.method,
+              list);
+    case 6:
+      return Expressions.call(
+              List.class,
+              null,
+              BuiltInMethod.LIST6.method,
+              list);
+    default:
+      return Expressions.call(
+              List.class,
+              null,
+              BuiltInMethod.LIST_N.method,
+              Expressions.newArrayInit(Comparable.class, list));
+    }
+  }
+
+  @Override public Expression generateAccessorWithoutNulls(List<Integer> 
fields) {
+    if (fields.size() < 2) {
+      return generateAccessor(fields);
+    }
+
+    ParameterExpression v1 = Expressions.parameter(javaRowClass, "v1");
+    Expressions.FluentList<Expression> list = Expressions.list();
+    for (int field : fields) {
+      list.add(fieldReference(v1, field));
+    }
+
+    // (v1.<field0> == null)
+    //   ? null
+    //   : (v1.<field1> == null)
+    //     ? null;
+    //     : ...
+    //         : FlatLists.of(...);
+    Expression exp = getListExpression(list);
+    for (int i = list.size() - 1; i >= 0; i--) {
+      exp =
+          Expressions.condition(
+              Expressions.equal(list.get(i), Expressions.constant(null)),
+              Expressions.constant(null),
+              exp);
     }
+    return Expressions.lambda(Function1.class, exp, v1);
   }
 
   @Override public Expression fieldReference(
diff --git a/core/src/main/java/org/apache/calcite/runtime/Utilities.java b/core/src/main/java/org/apache/calcite/runtime/Utilities.java
index 458a07c898..3409af7224 100644
--- a/core/src/main/java/org/apache/calcite/runtime/Utilities.java
+++ b/core/src/main/java/org/apache/calcite/runtime/Utilities.java
@@ -16,6 +16,8 @@
  */
 package org.apache.calcite.runtime;
 
+import org.apache.calcite.linq4j.EnumerableDefaults;
+
 import org.checkerframework.checker.nullness.qual.Nullable;
 
 import java.text.Collator;
@@ -250,6 +252,15 @@ public class Utilities {
                 : FlatLists.ComparableListImpl.compare(v0, v1);
   }
 
+  public static int compareNullsLastForMergeJoin(@Nullable Comparable v0, 
@Nullable Comparable v1) {
+    return EnumerableDefaults.compareNullsLastForMergeJoin(v0, v1);
+  }
+
+  public static int compareNullsLastForMergeJoin(@Nullable Comparable v0, 
@Nullable Comparable v1,
+      Comparator comparator) {
+    return EnumerableDefaults.compareNullsLastForMergeJoin(v0, v1, comparator);
+  }
+
   /** Creates a pattern builder. */
   public static Pattern.PatternBuilder patternBuilder() {
     return Pattern.builder();
diff --git a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
index 6d7eedd10f..19c5228e09 100644
--- a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
+++ b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
@@ -205,7 +205,7 @@ public enum BuiltInMethod {
       List.class, int.class, Consumer.class),
   MERGE_JOIN(EnumerableDefaults.class, "mergeJoin", Enumerable.class,
       Enumerable.class, Function1.class, Function1.class, Predicate2.class, 
Function2.class,
-      JoinType.class, Comparator.class),
+      JoinType.class, Comparator.class, EqualityComparer.class),
   SLICE0(Enumerables.class, "slice0", Enumerable.class),
   SEMI_JOIN(EnumerableDefaults.class, "semiJoin", Enumerable.class,
       Enumerable.class, Function1.class, Function1.class,
diff --git a/core/src/test/java/org/apache/calcite/runtime/EnumerablesTest.java b/core/src/test/java/org/apache/calcite/runtime/EnumerablesTest.java
index 13970b53a7..0010482f95 100644
--- a/core/src/test/java/org/apache/calcite/runtime/EnumerablesTest.java
+++ b/core/src/test/java/org/apache/calcite/runtime/EnumerablesTest.java
@@ -414,7 +414,7 @@ class EnumerablesTest {
             e1 -> e1.name,
             e2 -> e2.name,
             (e1, e2) -> e1.deptno < e2.deptno,
-            (v0, v1) -> v0 + "-" + v1, JoinType.INNER, null).toList(),
+            (v0, v1) -> v0 + "-" + v1, JoinType.INNER, null, null).toList(),
         hasToString("["
             + "Emp(1, Fred)-Emp(2, Fred), "
             + "Emp(1, Fred)-Emp(3, Fred), "
@@ -430,7 +430,7 @@ class EnumerablesTest {
             e2 -> e2.name,
             e1 -> e1.name,
             (e2, e1) -> e2.deptno > e1.deptno,
-            (v0, v1) -> v0 + "-" + v1, JoinType.INNER, null).toList(),
+            (v0, v1) -> v0 + "-" + v1, JoinType.INNER, null, null).toList(),
         hasToString("["
             + "Emp(2, Fred)-Emp(1, Fred), "
             + "Emp(3, Fred)-Emp(1, Fred), "
@@ -446,7 +446,7 @@ class EnumerablesTest {
             e1 -> e1.name,
             e2 -> e2.name,
             (e1, e2) -> e1.deptno == e2.deptno * 2,
-            (v0, v1) -> v0 + "-" + v1, JoinType.INNER, null).toList(),
+            (v0, v1) -> v0 + "-" + v1, JoinType.INNER, null, null).toList(),
         hasToString("[]"));
 
     assertThat(
@@ -456,7 +456,7 @@ class EnumerablesTest {
             e2 -> e2.name,
             e1 -> e1.name,
             (e2, e1) -> e2.deptno == e1.deptno * 2,
-            (v0, v1) -> v0 + "-" + v1, JoinType.INNER, null).toList(),
+            (v0, v1) -> v0 + "-" + v1, JoinType.INNER, null, null).toList(),
         hasToString("[Emp(2, Fred)-Emp(1, Fred)]"));
 
     assertThat(
@@ -466,7 +466,7 @@ class EnumerablesTest {
             e2 -> e2.name,
             e1 -> e1.name,
             (e2, e1) -> e2.deptno == e1.deptno + 2,
-            (v0, v1) -> v0 + "-" + v1, JoinType.INNER, null).toList(),
+            (v0, v1) -> v0 + "-" + v1, JoinType.INNER, null, null).toList(),
         hasToString("[Emp(3, Fred)-Emp(1, Fred), Emp(5, Joe)-Emp(3, Joe)]"));
   }
 
@@ -493,7 +493,7 @@ class EnumerablesTest {
             null,
             (v0, v1) -> v0,
             JoinType.SEMI,
-            null).toList(),
+            null, null).toList(),
         hasToString("[Dept(10, Marketing),"
             + " Dept(20, Sales),"
             + " Dept(30, Research)]"));
@@ -522,7 +522,7 @@ class EnumerablesTest {
             (d, e) -> e.name.contains("a"),
             (v0, v1) -> v0,
             JoinType.SEMI,
-            null).toList(),
+            null, null).toList(),
         hasToString("[Dept(20, Sales)]"));
   }
 
@@ -550,7 +550,7 @@ class EnumerablesTest {
             (e, d) -> e.name.startsWith("T"),
             (v0, v1) -> v0,
             JoinType.SEMI,
-            null).toList(),
+            null, null).toList(),
         hasToString("[Emp(30, Theodore)]"));
   }
 
@@ -578,7 +578,7 @@ class EnumerablesTest {
             null,
             (v0, v1) -> v0,
             JoinType.ANTI,
-            null).toList(),
+            null, null).toList(),
         hasToString("[Dept(25, HR), Dept(40, Development)]"));
   }
 
@@ -605,7 +605,7 @@ class EnumerablesTest {
             (d, e) -> e.name.startsWith("F") || e.name.startsWith("S"),
             (v0, v1) -> v0,
             JoinType.ANTI,
-            null).toList(),
+            null, null).toList(),
         hasToString("[Dept(25, HR), Dept(30, Research), "
             + "Dept(40, Development)]"));
   }
@@ -634,7 +634,7 @@ class EnumerablesTest {
             (e, d) -> d.deptno < 30,
             (v0, v1) -> v0,
             JoinType.ANTI,
-            null).toList(),
+            null, null).toList(),
         hasToString("[Emp(30, Fred), Emp(20, Sebastian), Emp(20, Zoey)]"));
   }
 
@@ -661,7 +661,7 @@ class EnumerablesTest {
             null,
             (v0, v1) -> v0 + "-" + v1,
             JoinType.LEFT,
-            null).toList(),
+            null, null).toList(),
         hasToString("[Dept(10, Marketing)-Emp(10, Fred),"
             + " Dept(20, Sales)-Emp(20, Theodore),"
             + " Dept(20, Sales)-Emp(20, Sebastian),"
@@ -694,7 +694,7 @@ class EnumerablesTest {
             (d, e) -> e.name.contains("a"),
             (v0, v1) -> v0 + "-" + v1,
             JoinType.LEFT,
-            null).toList(),
+            null, null).toList(),
         hasToString("[Dept(10, Marketing)-null,"
             + " Dept(20, Sales)-Emp(20, Sebastian),"
             + " Dept(25, HR)-null,"
@@ -726,7 +726,7 @@ class EnumerablesTest {
             (e, d) -> e.name.startsWith("T"),
             (v0, v1) -> v0 + "-" + v1,
             JoinType.LEFT,
-            null).toList(),
+            null, null).toList(),
         hasToString("[Emp(30, Fred)-null,"
             + " Emp(20, Sebastian)-null,"
             + " Emp(30, Theodore)-Dept(30, Theodore),"
diff --git a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableHashJoinTest.java b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableHashJoinTest.java
index 1e8be1b01c..1207e405c2 100644
--- a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableHashJoinTest.java
+++ b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableHashJoinTest.java
@@ -56,24 +56,6 @@ class EnumerableHashJoinTest {
             "empid=150; name=Sebastian; dept=Sales");
   }
 
-  @Test void innerJoinWithPredicate() {
-    tester(false, new HrSchema())
-        .query(
-            "select e.empid, e.name, d.name as dept from emps e join depts d"
-                + " on e.deptno=d.deptno and e.empid<150 and e.empid>d.deptno")
-        .explainContains("EnumerableCalc(expr#0..4=[{inputs}], empid=[$t0], 
name=[$t2], "
-            + "dept=[$t4])\n"
-            + "  EnumerableHashJoin(condition=[AND(=($1, $3), >($0, $3))], 
joinType=[inner])\n"
-            + "    EnumerableCalc(expr#0..4=[{inputs}], expr#5=[150], 
expr#6=[<($t0, $t5)], "
-            + "proj#0..2=[{exprs}], $condition=[$t6])\n"
-            + "      EnumerableTableScan(table=[[s, emps]])\n"
-            + "    EnumerableCalc(expr#0..3=[{inputs}], proj#0..1=[{exprs}])\n"
-            + "      EnumerableTableScan(table=[[s, depts]])\n")
-        .returnsUnordered(
-            "empid=100; name=Bill; dept=Sales",
-            "empid=110; name=Theodore; dept=Sales");
-  }
-
   @Test void leftOuterJoin() {
     tester(false, new HrSchema())
         .query(
@@ -201,6 +183,124 @@ class EnumerableHashJoinTest {
             "name=Sebastian; salary=7000.0");
   }
 
+  @Test void innerJoinWithPredicate() {
+    tester(false, new HrSchema())
+        .query(
+            "select e.empid, e.name, d.name as dept from emps e join depts d"
+                + " on e.deptno=d.deptno and e.empid<150 and e.empid>d.deptno")
+        .explainContains("EnumerableCalc(expr#0..4=[{inputs}], empid=[$t0], 
name=[$t2], "
+            + "dept=[$t4])\n"
+            + "  EnumerableHashJoin(condition=[AND(=($1, $3), >($0, $3))], 
joinType=[inner])\n"
+            + "    EnumerableCalc(expr#0..4=[{inputs}], expr#5=[150], 
expr#6=[<($t0, $t5)], "
+            + "proj#0..2=[{exprs}], $condition=[$t6])\n"
+            + "      EnumerableTableScan(table=[[s, emps]])\n"
+            + "    EnumerableCalc(expr#0..3=[{inputs}], proj#0..1=[{exprs}])\n"
+            + "      EnumerableTableScan(table=[[s, depts]])\n")
+        .returnsUnordered(
+            "empid=100; name=Bill; dept=Sales",
+            "empid=110; name=Theodore; dept=Sales");
+  }
+
+  @Test void innerJoinWithCompositeKeyAndNullValues() {
+    tester(false, new HrSchema())
+        .query(
+            "select e1.empid from emps e1 join emps e2 "
+                + "on e1.deptno=e2.deptno and e1.commission=e2.commission")
+        .withHook(Hook.PLANNER, (Consumer<RelOptPlanner>) planner ->
+            planner.removeRule(EnumerableRules.ENUMERABLE_MERGE_JOIN_RULE))
+        .explainContains("EnumerableCalc(expr#0..4=[{inputs}], empid=[$t0])\n"
+            + "  EnumerableHashJoin(condition=[AND(=($1, $3), =($2, $4))], 
joinType=[inner])\n"
+            + "    EnumerableCalc(expr#0..4=[{inputs}], proj#0..1=[{exprs}], 
commission=[$t4])\n"
+            + "      EnumerableTableScan(table=[[s, emps]])\n"
+            + "    EnumerableCalc(expr#0..4=[{inputs}], deptno=[$t1], 
commission=[$t4])\n"
+            + "      EnumerableTableScan(table=[[s, emps]])\n")
+        .returnsUnordered(
+            "empid=100",
+            "empid=110",
+            "empid=200");
+  }
+
+  @Test void leftOuterJoinWithCompositeKeyAndNullValues() {
+    tester(false, new HrSchema())
+        .query(
+            "select e1.empid, e2.empid from emps e1 left outer join emps e2 "
+                + "on e1.deptno=e2.deptno and e1.commission=e2.commission")
+        .withHook(Hook.PLANNER, (Consumer<RelOptPlanner>) planner ->
+            planner.removeRule(EnumerableRules.ENUMERABLE_MERGE_JOIN_RULE))
+        .explainContains("EnumerableCalc(expr#0..5=[{inputs}], empid=[$t0], 
empid0=[$t3])\n"
+            + "  EnumerableHashJoin(condition=[AND(=($1, $4), =($2, $5))], 
joinType=[left])\n"
+            + "    EnumerableCalc(expr#0..4=[{inputs}], proj#0..1=[{exprs}], 
commission=[$t4])\n"
+            + "      EnumerableTableScan(table=[[s, emps]])\n"
+            + "    EnumerableCalc(expr#0..4=[{inputs}], proj#0..1=[{exprs}], 
commission=[$t4])\n"
+            + "      EnumerableTableScan(table=[[s, emps]])\n")
+        .returnsUnordered(
+            "empid=100; empid=100",
+            "empid=110; empid=110",
+            "empid=150; empid=null",
+            "empid=200; empid=200");
+  }
+
+  @Test void rightOuterJoinWithCompositeKeyAndNullValues() {
+    tester(false, new HrSchema())
+        .query(
+            "select e1.empid, e2.empid from emps e1 right outer join emps e2 "
+                + "on e1.deptno=e2.deptno and e1.commission=e2.commission")
+        .withHook(Hook.PLANNER, (Consumer<RelOptPlanner>) planner ->
+            planner.removeRule(EnumerableRules.ENUMERABLE_MERGE_JOIN_RULE))
+        .explainContains("EnumerableCalc(expr#0..5=[{inputs}], empid=[$t0], 
empid0=[$t3])\n"
+            + "  EnumerableHashJoin(condition=[AND(=($1, $4), =($2, $5))], 
joinType=[right])\n"
+            + "    EnumerableCalc(expr#0..4=[{inputs}], proj#0..1=[{exprs}], 
commission=[$t4])\n"
+            + "      EnumerableTableScan(table=[[s, emps]])\n"
+            + "    EnumerableCalc(expr#0..4=[{inputs}], proj#0..1=[{exprs}], 
commission=[$t4])\n"
+            + "      EnumerableTableScan(table=[[s, emps]])\n")
+        .returnsUnordered(
+            "empid=100; empid=100",
+            "empid=110; empid=110",
+            "empid=200; empid=200",
+            "empid=null; empid=150");
+  }
+
+  @Test void fullOuterJoinWithCompositeKeyAndNullValues() {
+    tester(false, new HrSchema())
+        .query(
+            "select e1.empid, e2.empid from emps e1 full outer join emps e2 "
+                + "on e1.deptno=e2.deptno and e1.commission=e2.commission")
+        .withHook(Hook.PLANNER, (Consumer<RelOptPlanner>) planner ->
+            planner.removeRule(EnumerableRules.ENUMERABLE_MERGE_JOIN_RULE))
+        .explainContains("EnumerableCalc(expr#0..5=[{inputs}], empid=[$t0], 
empid0=[$t3])\n"
+            + "  EnumerableHashJoin(condition=[AND(=($1, $4), =($2, $5))], 
joinType=[full])\n"
+            + "    EnumerableCalc(expr#0..4=[{inputs}], proj#0..1=[{exprs}], 
commission=[$t4])\n"
+            + "      EnumerableTableScan(table=[[s, emps]])\n"
+            + "    EnumerableCalc(expr#0..4=[{inputs}], proj#0..1=[{exprs}], 
commission=[$t4])\n"
+            + "      EnumerableTableScan(table=[[s, emps]])\n")
+        .returnsUnordered(
+            "empid=100; empid=100",
+            "empid=110; empid=110",
+            "empid=150; empid=null",
+            "empid=200; empid=200",
+            "empid=null; empid=150");
+  }
+
+  @Test void semiJoinWithCompositeKeyAndNullValues() {
+    tester(true, new HrSchema())
+        .query(
+            "select e1.empid from emps e1 where exists (select 1 from emps e2 "
+                + "where e1.deptno=e2.deptno and e1.commission=e2.commission)")
+        .withHook(Hook.PLANNER, (Consumer<RelOptPlanner>) planner -> {
+          planner.removeRule(EnumerableRules.ENUMERABLE_MERGE_JOIN_RULE);
+        })
+        .explainContains("EnumerableCalc(expr#0..2=[{inputs}], empid=[$t0])\n"
+            + "  EnumerableHashJoin(condition=[AND(=($1, $4), =($2, $7))], 
joinType=[semi])\n"
+            + "    EnumerableCalc(expr#0..4=[{inputs}], proj#0..1=[{exprs}], 
commission=[$t4])\n"
+            + "      EnumerableTableScan(table=[[s, emps]])\n"
+            + "    EnumerableCalc(expr#0..4=[{inputs}], expr#5=[IS NOT 
NULL($t4)], proj#0..4=[{exprs}], $condition=[$t5])\n"
+            + "      EnumerableTableScan(table=[[s, emps]])\n")
+        .returnsUnordered(
+            "empid=100",
+            "empid=110",
+            "empid=200");
+  }
+
   private CalciteAssert.AssertThat tester(boolean forceDecorrelate,
       Object schema) {
     return CalciteAssert.that()
diff --git a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableJoinTest.java b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableJoinTest.java
index 30f69f1034..a15af886d6 100644
--- a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableJoinTest.java
+++ b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableJoinTest.java
@@ -29,9 +29,11 @@ import org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.calcite.test.CalciteAssert;
 import org.apache.calcite.test.schemata.hr.HierarchySchema;
 import org.apache.calcite.test.schemata.hr.HrSchema;
+import org.apache.calcite.test.schemata.hr.HrSchemaBig;
 
 import org.junit.jupiter.api.Test;
 
+import java.util.Arrays;
 import java.util.function.Consumer;
 
 /**
@@ -218,18 +220,79 @@ class EnumerableJoinTest {
   /** Test case for
    * <a 
href="https://issues.apache.org/jira/browse/CALCITE-3846";>[CALCITE-3846]
    * EnumerableMergeJoin: wrong comparison of composite key with null 
values</a>. */
-  @Test void testMergeJoinWithCompositeKeyAndNullValues() {
-    tester(false, new HrSchema())
+  @Test void testMergeJoinInnerWithCompositeKeyAndNullValues() {
+    checkMergeJoinWithCompositeKeyAndNullValues(
+        false,
+        JoinRelType.INNER,
+        "empid=110; empid0=110",
+        "empid=100; empid0=100",
+        "empid=200; empid0=200");
+    checkMergeJoinWithCompositeKeyAndNullValues(
+        true,
+        JoinRelType.INNER,
+        "empid=48; empid0=48",
+        "empid=4; empid0=4",
+        "empid=4; empid0=8");
+  }
+
+  @Test void testMergeJoinLeftWithCompositeKeyAndNullValues() {
+    checkMergeJoinWithCompositeKeyAndNullValues(
+        false,
+        JoinRelType.LEFT,
+        "empid=110; empid0=110",
+        "empid=100; empid0=100",
+        "empid=150; empid0=null",
+        "empid=200; empid0=200");
+    checkMergeJoinWithCompositeKeyAndNullValues(
+        true,
+        JoinRelType.LEFT,
+        "empid=48; empid0=48",
+        "empid=47; empid0=null",
+        "empid=4; empid0=4");
+  }
+
+  @Test void testMergeJoinSemiWithCompositeKeyAndNullValues() {
+    checkMergeJoinWithCompositeKeyAndNullValues(
+        false,
+        JoinRelType.SEMI,
+        "empid=110",
+        "empid=100",
+        "empid=200");
+    checkMergeJoinWithCompositeKeyAndNullValues(
+        true,
+        JoinRelType.SEMI,
+        "empid=48",
+        "empid=4",
+        "empid=8");
+  }
+
+  @Test void testMergeJoinAntiWithCompositeKeyAndNullValues() {
+    checkMergeJoinWithCompositeKeyAndNullValues(
+        false,
+        JoinRelType.ANTI,
+        "empid=150");
+    checkMergeJoinWithCompositeKeyAndNullValues(
+        true,
+        JoinRelType.ANTI,
+        "empid=47",
+        "empid=3",
+        "empid=7");
+  }
+
+  private void checkMergeJoinWithCompositeKeyAndNullValues(boolean bigSchema, 
JoinRelType joinType,
+      String... expected) {
+    CalciteAssert.AssertQuery checker =
+        tester(false, bigSchema ? new HrSchemaBig() : new HrSchema())
         .withHook(Hook.PLANNER, (Consumer<RelOptPlanner>) planner -> {
           planner.addRule(EnumerableRules.ENUMERABLE_MERGE_JOIN_RULE);
           planner.removeRule(EnumerableRules.ENUMERABLE_JOIN_RULE);
         })
         .withRel(builder -> builder
-            .scan("s", "emps")
-            .sort(builder.field("deptno"), builder.field("commission"))
-            .scan("s", "emps")
-            .sort(builder.field("deptno"), builder.field("commission"))
-            .join(JoinRelType.INNER,
+            .scan("s", "emps").as("e1")
+            .sort(builder.field("deptno"), builder.field("commission"), 
builder.field("empid"))
+            .scan("s", "emps").as("e2")
+            .sort(builder.field("deptno"), builder.field("commission"), 
builder.field("empid"))
+            .join(joinType,
                 builder.and(
                     builder.equals(
                         builder.field(2, 0, "deptno"),
@@ -237,22 +300,16 @@ class EnumerableJoinTest {
                     builder.equals(
                         builder.field(2, 0, "commission"),
                         builder.field(2, 1, "commission"))))
-            .project(
-                builder.field("empid"))
+            .project(joinType.projectsRight()
+                ? Arrays.asList(builder.field("e1", "empid"), 
builder.field("e2", "empid"))
+                : Arrays.asList(builder.field("e1", "empid")))
             .build())
-        .explainHookMatches("" // It is important that we have MergeJoin in 
the plan
-            + "EnumerableCalc(expr#0..4=[{inputs}], empid=[$t0])\n"
-            + "  EnumerableMergeJoin(condition=[AND(=($1, $3), =($2, $4))], 
joinType=[inner])\n"
-            + "    EnumerableSort(sort0=[$1], sort1=[$2], dir0=[ASC], 
dir1=[ASC])\n"
-            + "      EnumerableCalc(expr#0..4=[{inputs}], proj#0..1=[{exprs}], 
commission=[$t4])\n"
-            + "        EnumerableTableScan(table=[[s, emps]])\n"
-            + "    EnumerableSort(sort0=[$0], sort1=[$1], dir0=[ASC], 
dir1=[ASC])\n"
-            + "      EnumerableCalc(expr#0..4=[{inputs}], deptno=[$t1], 
commission=[$t4])\n"
-            + "        EnumerableTableScan(table=[[s, emps]])\n")
-        .returnsUnordered("empid=100",
-            "empid=110",
-            "empid=150",
-            "empid=200");
+        .explainHookContains("EnumerableMergeJoin"); // We must have MergeJoin 
in the plan
+    if (bigSchema) {
+      checker.returnsStartingWith(expected);
+    } else {
+      checker.returnsOrdered(expected);
+    }
   }
 
   /** Test case for
diff --git 
a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java 
b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java
index d623d8619d..851c0d40d3 100644
--- a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java
+++ b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java
@@ -1924,7 +1924,7 @@ public abstract class EnumerableDefaults {
         final Predicate1<TSource> predicate = v0 -> {
           TKey key = outerKeySelector.apply(v0);
           @SuppressWarnings("argument.type.incompatible")
-          Enumerable<TInner> innersOfKey = innerLookup.get().get(key);
+          Enumerable<TInner> innersOfKey = key == null ? null : 
innerLookup.get().get(key);
           if (innersOfKey == null) {
             return anti;
           }
@@ -1963,9 +1963,11 @@ public abstract class EnumerableDefaults {
                 ? inner.select(innerKeySelector).distinct()
                 : inner.select(innerKeySelector).distinct(comparer));
 
-        final Predicate1<TSource> predicate = anti
-            ? v0 -> !innerLookup.get().contains(outerKeySelector.apply(v0))
-            : v0 -> innerLookup.get().contains(outerKeySelector.apply(v0));
+        final Predicate1<TSource> predicate = v0 -> {
+          TKey key = outerKeySelector.apply(v0);
+          boolean found = key != null && innerLookup.get().contains(key);
+          return anti ? !found : found;
+        };
 
         return EnumerableDefaults.where(outer.enumerator(), predicate);
       }
@@ -2146,9 +2148,9 @@ public abstract class EnumerableDefaults {
 
   /**
    * Joins two inputs that are sorted on the key.
-   * Inputs must sorted in ascending order, nulls last.
+   * Inputs must be sorted in ascending order, nulls last.
    *
-   * @deprecated Use {@link #mergeJoin(Enumerable, Enumerable, Function1, 
Function1, Function2, JoinType, Comparator)}
+   * @deprecated Use {@link #mergeJoin(Enumerable, Enumerable, Function1, 
Function1, Predicate2, Function2, JoinType, Comparator, EqualityComparer)}
    */
   @Deprecated // to be removed before 2.0
   public static <TSource, TInner, TKey extends Comparable<TKey>, TResult> 
Enumerable<TResult>
@@ -2168,7 +2170,7 @@ public abstract class EnumerableDefaults {
           "not implemented, mergeJoin with generateNullsOnRight");
     }
     return mergeJoin(outer, inner, outerKeySelector, innerKeySelector, null, 
resultSelector,
-        JoinType.INNER, null);
+        JoinType.INNER, null, null);
   }
 
   /**
@@ -2191,7 +2193,7 @@ public abstract class EnumerableDefaults {
 
   /**
    * Joins two inputs that are sorted on the key.
-   * Inputs must sorted in ascending order, nulls last.
+   * Inputs must be sorted in ascending order, nulls last.
    */
   public static <TSource, TInner, TKey extends Comparable<TKey>, TResult> 
Enumerable<TResult>
       mergeJoin(final Enumerable<TSource> outer,
@@ -2202,7 +2204,7 @@ public abstract class EnumerableDefaults {
       final JoinType joinType,
       final Comparator<TKey> comparator) {
     return mergeJoin(outer, inner, outerKeySelector, innerKeySelector, null, 
resultSelector,
-        joinType, comparator);
+        joinType, comparator, null);
   }
 
   /**
@@ -2219,6 +2221,8 @@ public abstract class EnumerableDefaults {
    *                       types in the future (e.g. semi or anti joins).
    * @param comparator key comparator, possibly null (in which case {@link 
Comparable#compareTo}
    *                   will be used).
+   * @param equalityComparer key equality comparer, possibly null (in which 
case equals
+   *                         will be used), required to compare keys from the 
same input.
    *
    * <p>NOTE: The current API is experimental and subject to change without
    * notice.
@@ -2232,14 +2236,15 @@ public abstract class EnumerableDefaults {
       final @Nullable Predicate2<TSource, TInner> extraPredicate,
       final Function2<TSource, @Nullable TInner, TResult> resultSelector,
       final JoinType joinType,
-      final @Nullable Comparator<TKey> comparator) {
+      final @Nullable Comparator<TKey> comparator,
+      final @Nullable EqualityComparer<TKey> equalityComparer) {
     if (!isMergeJoinSupported(joinType)) {
       throw new UnsupportedOperationException("MergeJoin unsupported for join 
type " + joinType);
     }
     return new AbstractEnumerable<TResult>() {
       @Override public Enumerator<TResult> enumerator() {
         return new MergeJoinEnumerator<>(outer, inner, outerKeySelector, 
innerKeySelector,
-            extraPredicate, resultSelector, joinType, comparator);
+            extraPredicate, resultSelector, joinType, comparator, 
equalityComparer);
       }
     };
   }
@@ -4144,7 +4149,7 @@ public abstract class EnumerableDefaults {
   }
 
   /** Enumerator that performs a merge join on its sorted inputs.
-   * Inputs must sorted in ascending order, nulls last.
+   * Inputs must be sorted in ascending order, nulls last.
    *
    * @param <TResult> result type
    * @param <TSource> left input record type
@@ -4167,6 +4172,8 @@ public abstract class EnumerableDefaults {
     private final JoinType joinType;
     // key comparator, possibly null (Comparable#compareTo to be used in that 
case)
     private final @Nullable Comparator<TKey> comparator;
+    // key equality comparer, possibly null (equals to be used in that case)
+    private final @Nullable EqualityComparer<TKey> equalityComparer;
     private boolean done;
     private @Nullable Enumerator<TResult> results = null;
     // used for LEFT/ANTI join: if right input is over, all remaining elements 
from left are results
@@ -4181,7 +4188,8 @@ public abstract class EnumerableDefaults {
         @Nullable Predicate2<TSource, TInner> extraPredicate,
         Function2<TSource, @Nullable TInner, TResult> resultSelector,
         JoinType joinType,
-        @Nullable Comparator<TKey> comparator) {
+        @Nullable Comparator<TKey> comparator,
+        @Nullable EqualityComparer<TKey> equalityComparer) {
       this.leftEnumerable = leftEnumerable;
       this.rightEnumerable = rightEnumerable;
       this.outerKeySelector = outerKeySelector;
@@ -4190,6 +4198,7 @@ public abstract class EnumerableDefaults {
       this.resultSelector = resultSelector;
       this.joinType = joinType;
       this.comparator = comparator;
+      this.equalityComparer = equalityComparer;
       start();
     }
 
@@ -4258,15 +4267,18 @@ public abstract class EnumerableDefaults {
       results = Linq4j.emptyEnumerator();
     }
 
+    /** Method to compare keys from left and right input (nulls must not be 
considered equal). */
     private int compare(TKey key1, TKey key2) {
-      return comparator != null ? comparator.compare(key1, key2) : 
compareNullsLast(key1, key2);
+      return comparator != null
+          ? comparator.compare(key1, key2)
+          : compareNullsLastForMergeJoin(key1, key2);
     }
 
-    private int compareNullsLast(TKey v0, TKey v1) {
-      return v0 == v1 ? 0
-          : v0 == null ? 1
-              : v1 == null ? -1
-                  : v0.compareTo(v1);
+    /** Method to compare keys from the same input (nulls must be considered 
equal). */
+    private boolean compareEquals(TKey key1, TKey key2) {
+      return equalityComparer != null
+          ? equalityComparer.equal(key1, key2)
+          : Objects.equals(key1, key2);
     }
 
     /** Moves to the next key that is present in both sides. Populates
@@ -4291,7 +4303,14 @@ public abstract class EnumerableDefaults {
             done = true;
             return false;
           }
-          int c = compare(leftKey, rightKey);
+          int c;
+          try {
+            c = compare(leftKey, rightKey);
+          } catch (BothValuesAreNullException e) {
+            // consider the first (left) value as "bigger", to advance on the 
right value
+            // and continue with the algorithm
+            c = 1;
+          }
           if (c == 0) {
             break;
           }
@@ -4390,13 +4409,7 @@ public abstract class EnumerableDefaults {
           // if we reach a null key, we are done (except LEFT join, that needs 
to process LHS fully)
           break;
         }
-        int c = compare(leftKey, leftKey2);
-        if (c != 0) {
-          if (c > 0) {
-            throw new IllegalStateException(
-                "mergeJoin assumes inputs sorted in ascending order, " + 
"however '"
-                    + leftKey + "' is greater than '" + leftKey2 + "'");
-          }
+        if (!compareEquals(leftKey, leftKey2)) {
           return true;
         }
         lefts.add(left);
@@ -4423,13 +4436,7 @@ public abstract class EnumerableDefaults {
           // if we reach a null key, we are done
           break;
         }
-        int c = compare(rightKey, rightKey2);
-        if (c != 0) {
-          if (c > 0) {
-            throw new IllegalStateException(
-                "mergeJoin assumes input sorted in ascending order, " + 
"however '"
-                    + rightKey + "' is greater than '" + rightKey2 + "'");
-          }
+        if (!compareEquals(rightKey, rightKey2)) {
           return true;
         }
         rights.add(right);
@@ -4495,6 +4502,36 @@ public abstract class EnumerableDefaults {
     }
   }
 
+  /**
+   * Exception used for control flow (it does not populate the stack trace to 
be more efficient)
+   * to signal that both values are <code>null</code>, so that the caller 
method (i.e. MergeJoin
+   * algorithm) can act accordingly.
+   */
+  private static class BothValuesAreNullException extends RuntimeException {
+    @Override public synchronized Throwable fillInStackTrace() {
+      return this;
+    }
+  }
+
+  public static int compareNullsLastForMergeJoin(@Nullable Comparable v0, 
@Nullable Comparable v1) {
+    return compareNullsLastForMergeJoin(v0, v1, null);
+  }
+
+  public static int compareNullsLastForMergeJoin(@Nullable Comparable v0, 
@Nullable Comparable v1,
+      @Nullable Comparator comparator) {
+    // Special code for mergeJoin algorithm: in case of two null values, they 
must not be
+    // considered as equal (otherwise the join would return incorrect results)
+    if (v0 == null && v1 == null) {
+      throw new BothValuesAreNullException();
+    }
+
+    //noinspection unchecked
+    return v0 == v1 ? 0
+        : v0 == null ? 1
+            : v1 == null ? -1
+                : comparator == null ? v0.compareTo(v1) : 
comparator.compare(v0, v1);
+  }
+
   /** Enumerates the elements of a cartesian product of two inputs.
    *
    * @param <TResult> result type


Reply via email to