This is an automated email from the ASF dual-hosted git repository.
rubenql pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/main by this push:
new 3aee0b86aa [CALCITE-5732] EnumerableHashJoin and EnumerableMergeJoin
on composite key return rows matching condition 'null = null'
3aee0b86aa is described below
commit 3aee0b86aa23476cbdecc75ad5d43b936a6fff7b
Author: rubenada <[email protected]>
AuthorDate: Wed Jul 12 14:45:05 2023 +0100
[CALCITE-5732] EnumerableHashJoin and EnumerableMergeJoin on composite key
return rows matching condition 'null = null'
---
.../adapter/enumerable/EnumerableHashJoin.java | 8 +-
.../adapter/enumerable/EnumerableMergeJoin.java | 7 +-
.../calcite/adapter/enumerable/PhysType.java | 26 +++-
.../calcite/adapter/enumerable/PhysTypeImpl.java | 168 +++++++++++++--------
.../java/org/apache/calcite/runtime/Utilities.java | 11 ++
.../org/apache/calcite/util/BuiltInMethod.java | 2 +-
.../apache/calcite/runtime/EnumerablesTest.java | 28 ++--
.../test/enumerable/EnumerableHashJoinTest.java | 136 ++++++++++++++---
.../test/enumerable/EnumerableJoinTest.java | 101 ++++++++++---
.../apache/calcite/linq4j/EnumerableDefaults.java | 105 ++++++++-----
10 files changed, 430 insertions(+), 162 deletions(-)
diff --git
a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableHashJoin.java
b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableHashJoin.java
index 259fb68503..cfe44da75d 100644
---
a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableHashJoin.java
+++
b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableHashJoin.java
@@ -218,8 +218,8 @@ public class EnumerableHashJoin extends Join implements
EnumerableRel {
Expressions.list(
leftExpression,
rightExpression,
- leftResult.physType.generateAccessor(joinInfo.leftKeys),
- rightResult.physType.generateAccessor(joinInfo.rightKeys),
+
leftResult.physType.generateAccessorWithoutNulls(joinInfo.leftKeys),
+
rightResult.physType.generateAccessorWithoutNulls(joinInfo.rightKeys),
Util.first(keyPhysType.comparer(),
Expressions.constant(null)),
predicate)))
@@ -264,8 +264,8 @@ public class EnumerableHashJoin extends Join implements
EnumerableRel {
BuiltInMethod.HASH_JOIN.method,
Expressions.list(
rightExpression,
- leftResult.physType.generateAccessor(joinInfo.leftKeys),
- rightResult.physType.generateAccessor(joinInfo.rightKeys),
+
leftResult.physType.generateAccessorWithoutNulls(joinInfo.leftKeys),
+
rightResult.physType.generateAccessorWithoutNulls(joinInfo.rightKeys),
EnumUtils.joinSelector(joinType,
physType,
ImmutableList.of(
diff --git
a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeJoin.java
b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeJoin.java
index d3f45d8d31..3f42e065c8 100644
---
a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeJoin.java
+++
b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeJoin.java
@@ -492,7 +492,7 @@ public class EnumerableMergeJoin extends Join implements
EnumerableRel {
RelFieldCollation.NullDirection.LAST));
}
final RelCollation collation = RelCollations.of(fieldCollations);
- final Expression comparator =
leftKeyPhysType.generateComparator(collation);
+ final Expression comparator =
leftKeyPhysType.generateMergeJoinComparator(collation);
return implementor.result(
physType,
@@ -512,6 +512,9 @@ public class EnumerableMergeJoin extends Join implements
EnumerableRel {
ImmutableList.of(
leftResult.physType, rightResult.physType)),
Expressions.constant(EnumUtils.toLinq4jJoinType(joinType)),
- comparator))).toBlock());
+ comparator,
+ Util.first(
+ leftKeyPhysType.comparer(),
+ Expressions.constant(null))))).toBlock());
}
}
diff --git
a/core/src/main/java/org/apache/calcite/adapter/enumerable/PhysType.java
b/core/src/main/java/org/apache/calcite/adapter/enumerable/PhysType.java
index 4447e91dfe..8d4eeb176c 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/PhysType.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/PhysType.java
@@ -112,11 +112,28 @@ public interface PhysType {
* public Object[] apply(Employee v1) {
* return FlatLists.of(v1.<fieldN>, v1.<fieldM>);
* }
- * }
* }</pre></blockquote>
*/
Expression generateAccessor(List<Integer> fields);
+ /** Similar to {@link #generateAccessor(List)}, but if one of the fields is
<code>null</code>,
+ * it will return <code>null</code>.
+ *
+ * <p>For example:
+ *
+ * <blockquote><pre>
+ * new Function1<Employee, Object[]> {
+ * public Object[] apply(Employee v1) {
+ * return v1.<fieldN> == null
+ * ? null
+ * : v1.<fieldM> == null
+ * ? null
+ * : FlatLists.of(v1.<fieldN>, v1.<fieldM>);
+ * }
+ * }</pre></blockquote>
+ */
+ Expression generateAccessorWithoutNulls(List<Integer> fields);
+
/** Generates a selector for the given fields from an expression, with the
* default row format. */
Expression generateSelector(
@@ -181,6 +198,13 @@ public interface PhysType {
Expression generateComparator(
RelCollation collation);
+ /** Similar to {@link #generateComparator(RelCollation)}, but with some
specificities for
+ * MergeJoin algorithm: it will not consider two <code>null</code> values as
equal.
+ *
+ * @see
org.apache.calcite.linq4j.EnumerableDefaults#compareNullsLastForMergeJoin
+ */
+ Expression generateMergeJoinComparator(RelCollation collation);
+
/** Returns a expression that yields a comparer, or null if this type
* is comparable. */
@Nullable Expression comparer();
diff --git
a/core/src/main/java/org/apache/calcite/adapter/enumerable/PhysTypeImpl.java
b/core/src/main/java/org/apache/calcite/adapter/enumerable/PhysTypeImpl.java
index 55a44b2a29..ab51dd3e35 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/PhysTypeImpl.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/PhysTypeImpl.java
@@ -407,6 +407,24 @@ public class PhysTypeImpl implements PhysType {
}
@Override public Expression generateComparator(RelCollation collation) {
+ return this.generateComparator(collation, fieldCollation -> {
+ final int index = fieldCollation.getFieldIndex();
+ final boolean nullsFirst =
+ fieldCollation.nullDirection
+ == RelFieldCollation.NullDirection.FIRST;
+ final boolean descending =
+ fieldCollation.getDirection()
+ == RelFieldCollation.Direction.DESCENDING;
+ return fieldNullable(index)
+ ? (nullsFirst != descending
+ ? "compareNullsFirst"
+ : "compareNullsLast")
+ : "compare";
+ });
+ }
+
+ private Expression generateComparator(RelCollation collation,
+ Function1<RelFieldCollation, String> compareMethodNameFunction) {
// int c;
// c = Utilities.compare(v0, v1);
// if (c != 0) return c; // or -c if descending
@@ -437,9 +455,6 @@ public class PhysTypeImpl implements PhysType {
default:
break;
}
- final boolean nullsFirst =
- fieldCollation.nullDirection
- == RelFieldCollation.NullDirection.FIRST;
final boolean descending =
fieldCollation.getDirection()
== RelFieldCollation.Direction.DESCENDING;
@@ -449,11 +464,7 @@ public class PhysTypeImpl implements PhysType {
parameterC,
Expressions.call(
Utilities.class,
- fieldNullable(index)
- ? (nullsFirst != descending
- ? "compareNullsFirst"
- : "compareNullsLast")
- : "compare",
+ compareMethodNameFunction.apply(fieldCollation),
Expressions.list(
arg0,
arg1)
@@ -511,6 +522,17 @@ public class PhysTypeImpl implements PhysType {
memberDeclarations);
}
+ @Override public Expression generateMergeJoinComparator(RelCollation
collation) {
+ return this.generateComparator(collation, fieldCollation -> {
+ // merge join keys must be sorted in ascending order, nulls last
+ assert fieldCollation.nullDirection ==
RelFieldCollation.NullDirection.LAST;
+ assert fieldCollation.getDirection() ==
RelFieldCollation.Direction.ASCENDING;
+ return fieldNullable(fieldCollation.getFieldIndex())
+ ? "compareNullsLastForMergeJoin"
+ : "compare";
+ });
+ }
+
@Override public RelDataType getRowType() {
return rowType;
}
@@ -616,65 +638,79 @@ public class PhysTypeImpl implements PhysType {
for (int field : fields) {
list.add(fieldReference(v1, field));
}
- switch (list.size()) {
- case 2:
- return Expressions.lambda(
- Function1.class,
- Expressions.call(
- List.class,
- null,
- BuiltInMethod.LIST2.method,
- list),
- v1);
- case 3:
- return Expressions.lambda(
- Function1.class,
- Expressions.call(
- List.class,
- null,
- BuiltInMethod.LIST3.method,
- list),
- v1);
- case 4:
- return Expressions.lambda(
- Function1.class,
- Expressions.call(
- List.class,
- null,
- BuiltInMethod.LIST4.method,
- list),
- v1);
- case 5:
- return Expressions.lambda(
- Function1.class,
- Expressions.call(
- List.class,
- null,
- BuiltInMethod.LIST5.method,
- list),
- v1);
- case 6:
- return Expressions.lambda(
- Function1.class,
- Expressions.call(
- List.class,
- null,
- BuiltInMethod.LIST6.method,
- list),
- v1);
- default:
- return Expressions.lambda(
- Function1.class,
- Expressions.call(
- List.class,
- null,
- BuiltInMethod.LIST_N.method,
- Expressions.newArrayInit(
- Comparable.class,
- list)),
- v1);
- }
+ return Expressions.lambda(Function1.class, getListExpression(list), v1);
+ }
+ }
+
+ private static Expression
getListExpression(Expressions.FluentList<Expression> list) {
+ assert list.size() >= 2;
+
+ switch (list.size()) {
+ case 2:
+ return Expressions.call(
+ List.class,
+ null,
+ BuiltInMethod.LIST2.method,
+ list);
+ case 3:
+ return Expressions.call(
+ List.class,
+ null,
+ BuiltInMethod.LIST3.method,
+ list);
+ case 4:
+ return Expressions.call(
+ List.class,
+ null,
+ BuiltInMethod.LIST4.method,
+ list);
+ case 5:
+ return Expressions.call(
+ List.class,
+ null,
+ BuiltInMethod.LIST5.method,
+ list);
+ case 6:
+ return Expressions.call(
+ List.class,
+ null,
+ BuiltInMethod.LIST6.method,
+ list);
+ default:
+ return Expressions.call(
+ List.class,
+ null,
+ BuiltInMethod.LIST_N.method,
+ Expressions.newArrayInit(Comparable.class, list));
+ }
+ }
+
+ @Override public Expression generateAccessorWithoutNulls(List<Integer>
fields) {
+ if (fields.size() < 2) {
+ return generateAccessor(fields);
+ }
+
+ ParameterExpression v1 = Expressions.parameter(javaRowClass, "v1");
+ Expressions.FluentList<Expression> list = Expressions.list();
+ for (int field : fields) {
+ list.add(fieldReference(v1, field));
+ }
+
+ // (v1.<field0> == null)
+ // ? null
+ // : (v1.<field1> == null)
+ // ? null
+ // : ...
+ // : FlatLists.of(...);
+ Expression exp = getListExpression(list);
+ for (int i = list.size() - 1; i >= 0; i--) {
+ exp =
+ Expressions.condition(
+ Expressions.equal(list.get(i), Expressions.constant(null)),
+ Expressions.constant(null),
+ exp);
}
+ return Expressions.lambda(Function1.class, exp, v1);
}
@Override public Expression fieldReference(
diff --git a/core/src/main/java/org/apache/calcite/runtime/Utilities.java
b/core/src/main/java/org/apache/calcite/runtime/Utilities.java
index 458a07c898..3409af7224 100644
--- a/core/src/main/java/org/apache/calcite/runtime/Utilities.java
+++ b/core/src/main/java/org/apache/calcite/runtime/Utilities.java
@@ -16,6 +16,8 @@
*/
package org.apache.calcite.runtime;
+import org.apache.calcite.linq4j.EnumerableDefaults;
+
import org.checkerframework.checker.nullness.qual.Nullable;
import java.text.Collator;
@@ -250,6 +252,15 @@ public class Utilities {
: FlatLists.ComparableListImpl.compare(v0, v1);
}
+ public static int compareNullsLastForMergeJoin(@Nullable Comparable v0,
@Nullable Comparable v1) {
+ return EnumerableDefaults.compareNullsLastForMergeJoin(v0, v1);
+ }
+
+ public static int compareNullsLastForMergeJoin(@Nullable Comparable v0,
@Nullable Comparable v1,
+ Comparator comparator) {
+ return EnumerableDefaults.compareNullsLastForMergeJoin(v0, v1, comparator);
+ }
+
/** Creates a pattern builder. */
public static Pattern.PatternBuilder patternBuilder() {
return Pattern.builder();
diff --git a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
index 6d7eedd10f..19c5228e09 100644
--- a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
+++ b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
@@ -205,7 +205,7 @@ public enum BuiltInMethod {
List.class, int.class, Consumer.class),
MERGE_JOIN(EnumerableDefaults.class, "mergeJoin", Enumerable.class,
Enumerable.class, Function1.class, Function1.class, Predicate2.class,
Function2.class,
- JoinType.class, Comparator.class),
+ JoinType.class, Comparator.class, EqualityComparer.class),
SLICE0(Enumerables.class, "slice0", Enumerable.class),
SEMI_JOIN(EnumerableDefaults.class, "semiJoin", Enumerable.class,
Enumerable.class, Function1.class, Function1.class,
diff --git a/core/src/test/java/org/apache/calcite/runtime/EnumerablesTest.java
b/core/src/test/java/org/apache/calcite/runtime/EnumerablesTest.java
index 13970b53a7..0010482f95 100644
--- a/core/src/test/java/org/apache/calcite/runtime/EnumerablesTest.java
+++ b/core/src/test/java/org/apache/calcite/runtime/EnumerablesTest.java
@@ -414,7 +414,7 @@ class EnumerablesTest {
e1 -> e1.name,
e2 -> e2.name,
(e1, e2) -> e1.deptno < e2.deptno,
- (v0, v1) -> v0 + "-" + v1, JoinType.INNER, null).toList(),
+ (v0, v1) -> v0 + "-" + v1, JoinType.INNER, null, null).toList(),
hasToString("["
+ "Emp(1, Fred)-Emp(2, Fred), "
+ "Emp(1, Fred)-Emp(3, Fred), "
@@ -430,7 +430,7 @@ class EnumerablesTest {
e2 -> e2.name,
e1 -> e1.name,
(e2, e1) -> e2.deptno > e1.deptno,
- (v0, v1) -> v0 + "-" + v1, JoinType.INNER, null).toList(),
+ (v0, v1) -> v0 + "-" + v1, JoinType.INNER, null, null).toList(),
hasToString("["
+ "Emp(2, Fred)-Emp(1, Fred), "
+ "Emp(3, Fred)-Emp(1, Fred), "
@@ -446,7 +446,7 @@ class EnumerablesTest {
e1 -> e1.name,
e2 -> e2.name,
(e1, e2) -> e1.deptno == e2.deptno * 2,
- (v0, v1) -> v0 + "-" + v1, JoinType.INNER, null).toList(),
+ (v0, v1) -> v0 + "-" + v1, JoinType.INNER, null, null).toList(),
hasToString("[]"));
assertThat(
@@ -456,7 +456,7 @@ class EnumerablesTest {
e2 -> e2.name,
e1 -> e1.name,
(e2, e1) -> e2.deptno == e1.deptno * 2,
- (v0, v1) -> v0 + "-" + v1, JoinType.INNER, null).toList(),
+ (v0, v1) -> v0 + "-" + v1, JoinType.INNER, null, null).toList(),
hasToString("[Emp(2, Fred)-Emp(1, Fred)]"));
assertThat(
@@ -466,7 +466,7 @@ class EnumerablesTest {
e2 -> e2.name,
e1 -> e1.name,
(e2, e1) -> e2.deptno == e1.deptno + 2,
- (v0, v1) -> v0 + "-" + v1, JoinType.INNER, null).toList(),
+ (v0, v1) -> v0 + "-" + v1, JoinType.INNER, null, null).toList(),
hasToString("[Emp(3, Fred)-Emp(1, Fred), Emp(5, Joe)-Emp(3, Joe)]"));
}
@@ -493,7 +493,7 @@ class EnumerablesTest {
null,
(v0, v1) -> v0,
JoinType.SEMI,
- null).toList(),
+ null, null).toList(),
hasToString("[Dept(10, Marketing),"
+ " Dept(20, Sales),"
+ " Dept(30, Research)]"));
@@ -522,7 +522,7 @@ class EnumerablesTest {
(d, e) -> e.name.contains("a"),
(v0, v1) -> v0,
JoinType.SEMI,
- null).toList(),
+ null, null).toList(),
hasToString("[Dept(20, Sales)]"));
}
@@ -550,7 +550,7 @@ class EnumerablesTest {
(e, d) -> e.name.startsWith("T"),
(v0, v1) -> v0,
JoinType.SEMI,
- null).toList(),
+ null, null).toList(),
hasToString("[Emp(30, Theodore)]"));
}
@@ -578,7 +578,7 @@ class EnumerablesTest {
null,
(v0, v1) -> v0,
JoinType.ANTI,
- null).toList(),
+ null, null).toList(),
hasToString("[Dept(25, HR), Dept(40, Development)]"));
}
@@ -605,7 +605,7 @@ class EnumerablesTest {
(d, e) -> e.name.startsWith("F") || e.name.startsWith("S"),
(v0, v1) -> v0,
JoinType.ANTI,
- null).toList(),
+ null, null).toList(),
hasToString("[Dept(25, HR), Dept(30, Research), "
+ "Dept(40, Development)]"));
}
@@ -634,7 +634,7 @@ class EnumerablesTest {
(e, d) -> d.deptno < 30,
(v0, v1) -> v0,
JoinType.ANTI,
- null).toList(),
+ null, null).toList(),
hasToString("[Emp(30, Fred), Emp(20, Sebastian), Emp(20, Zoey)]"));
}
@@ -661,7 +661,7 @@ class EnumerablesTest {
null,
(v0, v1) -> v0 + "-" + v1,
JoinType.LEFT,
- null).toList(),
+ null, null).toList(),
hasToString("[Dept(10, Marketing)-Emp(10, Fred),"
+ " Dept(20, Sales)-Emp(20, Theodore),"
+ " Dept(20, Sales)-Emp(20, Sebastian),"
@@ -694,7 +694,7 @@ class EnumerablesTest {
(d, e) -> e.name.contains("a"),
(v0, v1) -> v0 + "-" + v1,
JoinType.LEFT,
- null).toList(),
+ null, null).toList(),
hasToString("[Dept(10, Marketing)-null,"
+ " Dept(20, Sales)-Emp(20, Sebastian),"
+ " Dept(25, HR)-null,"
@@ -726,7 +726,7 @@ class EnumerablesTest {
(e, d) -> e.name.startsWith("T"),
(v0, v1) -> v0 + "-" + v1,
JoinType.LEFT,
- null).toList(),
+ null, null).toList(),
hasToString("[Emp(30, Fred)-null,"
+ " Emp(20, Sebastian)-null,"
+ " Emp(30, Theodore)-Dept(30, Theodore),"
diff --git
a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableHashJoinTest.java
b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableHashJoinTest.java
index 1e8be1b01c..1207e405c2 100644
---
a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableHashJoinTest.java
+++
b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableHashJoinTest.java
@@ -56,24 +56,6 @@ class EnumerableHashJoinTest {
"empid=150; name=Sebastian; dept=Sales");
}
- @Test void innerJoinWithPredicate() {
- tester(false, new HrSchema())
- .query(
- "select e.empid, e.name, d.name as dept from emps e join depts d"
- + " on e.deptno=d.deptno and e.empid<150 and e.empid>d.deptno")
- .explainContains("EnumerableCalc(expr#0..4=[{inputs}], empid=[$t0],
name=[$t2], "
- + "dept=[$t4])\n"
- + " EnumerableHashJoin(condition=[AND(=($1, $3), >($0, $3))],
joinType=[inner])\n"
- + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=[150],
expr#6=[<($t0, $t5)], "
- + "proj#0..2=[{exprs}], $condition=[$t6])\n"
- + " EnumerableTableScan(table=[[s, emps]])\n"
- + " EnumerableCalc(expr#0..3=[{inputs}], proj#0..1=[{exprs}])\n"
- + " EnumerableTableScan(table=[[s, depts]])\n")
- .returnsUnordered(
- "empid=100; name=Bill; dept=Sales",
- "empid=110; name=Theodore; dept=Sales");
- }
-
@Test void leftOuterJoin() {
tester(false, new HrSchema())
.query(
@@ -201,6 +183,124 @@ class EnumerableHashJoinTest {
"name=Sebastian; salary=7000.0");
}
+ @Test void innerJoinWithPredicate() {
+ tester(false, new HrSchema())
+ .query(
+ "select e.empid, e.name, d.name as dept from emps e join depts d"
+ + " on e.deptno=d.deptno and e.empid<150 and e.empid>d.deptno")
+ .explainContains("EnumerableCalc(expr#0..4=[{inputs}], empid=[$t0],
name=[$t2], "
+ + "dept=[$t4])\n"
+ + " EnumerableHashJoin(condition=[AND(=($1, $3), >($0, $3))],
joinType=[inner])\n"
+ + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=[150],
expr#6=[<($t0, $t5)], "
+ + "proj#0..2=[{exprs}], $condition=[$t6])\n"
+ + " EnumerableTableScan(table=[[s, emps]])\n"
+ + " EnumerableCalc(expr#0..3=[{inputs}], proj#0..1=[{exprs}])\n"
+ + " EnumerableTableScan(table=[[s, depts]])\n")
+ .returnsUnordered(
+ "empid=100; name=Bill; dept=Sales",
+ "empid=110; name=Theodore; dept=Sales");
+ }
+
+ @Test void innerJoinWithCompositeKeyAndNullValues() {
+ tester(false, new HrSchema())
+ .query(
+ "select e1.empid from emps e1 join emps e2 "
+ + "on e1.deptno=e2.deptno and e1.commission=e2.commission")
+ .withHook(Hook.PLANNER, (Consumer<RelOptPlanner>) planner ->
+ planner.removeRule(EnumerableRules.ENUMERABLE_MERGE_JOIN_RULE))
+ .explainContains("EnumerableCalc(expr#0..4=[{inputs}], empid=[$t0])\n"
+ + " EnumerableHashJoin(condition=[AND(=($1, $3), =($2, $4))],
joinType=[inner])\n"
+ + " EnumerableCalc(expr#0..4=[{inputs}], proj#0..1=[{exprs}],
commission=[$t4])\n"
+ + " EnumerableTableScan(table=[[s, emps]])\n"
+ + " EnumerableCalc(expr#0..4=[{inputs}], deptno=[$t1],
commission=[$t4])\n"
+ + " EnumerableTableScan(table=[[s, emps]])\n")
+ .returnsUnordered(
+ "empid=100",
+ "empid=110",
+ "empid=200");
+ }
+
+ @Test void leftOuterJoinWithCompositeKeyAndNullValues() {
+ tester(false, new HrSchema())
+ .query(
+ "select e1.empid, e2.empid from emps e1 left outer join emps e2 "
+ + "on e1.deptno=e2.deptno and e1.commission=e2.commission")
+ .withHook(Hook.PLANNER, (Consumer<RelOptPlanner>) planner ->
+ planner.removeRule(EnumerableRules.ENUMERABLE_MERGE_JOIN_RULE))
+ .explainContains("EnumerableCalc(expr#0..5=[{inputs}], empid=[$t0],
empid0=[$t3])\n"
+ + " EnumerableHashJoin(condition=[AND(=($1, $4), =($2, $5))],
joinType=[left])\n"
+ + " EnumerableCalc(expr#0..4=[{inputs}], proj#0..1=[{exprs}],
commission=[$t4])\n"
+ + " EnumerableTableScan(table=[[s, emps]])\n"
+ + " EnumerableCalc(expr#0..4=[{inputs}], proj#0..1=[{exprs}],
commission=[$t4])\n"
+ + " EnumerableTableScan(table=[[s, emps]])\n")
+ .returnsUnordered(
+ "empid=100; empid=100",
+ "empid=110; empid=110",
+ "empid=150; empid=null",
+ "empid=200; empid=200");
+ }
+
+ @Test void rightOuterJoinWithCompositeKeyAndNullValues() {
+ tester(false, new HrSchema())
+ .query(
+ "select e1.empid, e2.empid from emps e1 right outer join emps e2 "
+ + "on e1.deptno=e2.deptno and e1.commission=e2.commission")
+ .withHook(Hook.PLANNER, (Consumer<RelOptPlanner>) planner ->
+ planner.removeRule(EnumerableRules.ENUMERABLE_MERGE_JOIN_RULE))
+ .explainContains("EnumerableCalc(expr#0..5=[{inputs}], empid=[$t0],
empid0=[$t3])\n"
+ + " EnumerableHashJoin(condition=[AND(=($1, $4), =($2, $5))],
joinType=[right])\n"
+ + " EnumerableCalc(expr#0..4=[{inputs}], proj#0..1=[{exprs}],
commission=[$t4])\n"
+ + " EnumerableTableScan(table=[[s, emps]])\n"
+ + " EnumerableCalc(expr#0..4=[{inputs}], proj#0..1=[{exprs}],
commission=[$t4])\n"
+ + " EnumerableTableScan(table=[[s, emps]])\n")
+ .returnsUnordered(
+ "empid=100; empid=100",
+ "empid=110; empid=110",
+ "empid=200; empid=200",
+ "empid=null; empid=150");
+ }
+
+ @Test void fullOuterJoinWithCompositeKeyAndNullValues() {
+ tester(false, new HrSchema())
+ .query(
+ "select e1.empid, e2.empid from emps e1 full outer join emps e2 "
+ + "on e1.deptno=e2.deptno and e1.commission=e2.commission")
+ .withHook(Hook.PLANNER, (Consumer<RelOptPlanner>) planner ->
+ planner.removeRule(EnumerableRules.ENUMERABLE_MERGE_JOIN_RULE))
+ .explainContains("EnumerableCalc(expr#0..5=[{inputs}], empid=[$t0],
empid0=[$t3])\n"
+ + " EnumerableHashJoin(condition=[AND(=($1, $4), =($2, $5))],
joinType=[full])\n"
+ + " EnumerableCalc(expr#0..4=[{inputs}], proj#0..1=[{exprs}],
commission=[$t4])\n"
+ + " EnumerableTableScan(table=[[s, emps]])\n"
+ + " EnumerableCalc(expr#0..4=[{inputs}], proj#0..1=[{exprs}],
commission=[$t4])\n"
+ + " EnumerableTableScan(table=[[s, emps]])\n")
+ .returnsUnordered(
+ "empid=100; empid=100",
+ "empid=110; empid=110",
+ "empid=150; empid=null",
+ "empid=200; empid=200",
+ "empid=null; empid=150");
+ }
+
+ @Test void semiJoinWithCompositeKeyAndNullValues() {
+ tester(true, new HrSchema())
+ .query(
+ "select e1.empid from emps e1 where exists (select 1 from emps e2 "
+ + "where e1.deptno=e2.deptno and e1.commission=e2.commission)")
+ .withHook(Hook.PLANNER, (Consumer<RelOptPlanner>) planner -> {
+ planner.removeRule(EnumerableRules.ENUMERABLE_MERGE_JOIN_RULE);
+ })
+ .explainContains("EnumerableCalc(expr#0..2=[{inputs}], empid=[$t0])\n"
+ + " EnumerableHashJoin(condition=[AND(=($1, $4), =($2, $7))],
joinType=[semi])\n"
+ + " EnumerableCalc(expr#0..4=[{inputs}], proj#0..1=[{exprs}],
commission=[$t4])\n"
+ + " EnumerableTableScan(table=[[s, emps]])\n"
+ + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=[IS NOT
NULL($t4)], proj#0..4=[{exprs}], $condition=[$t5])\n"
+ + " EnumerableTableScan(table=[[s, emps]])\n")
+ .returnsUnordered(
+ "empid=100",
+ "empid=110",
+ "empid=200");
+ }
+
private CalciteAssert.AssertThat tester(boolean forceDecorrelate,
Object schema) {
return CalciteAssert.that()
diff --git
a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableJoinTest.java
b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableJoinTest.java
index 30f69f1034..a15af886d6 100644
---
a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableJoinTest.java
+++
b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableJoinTest.java
@@ -29,9 +29,11 @@ import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.test.CalciteAssert;
import org.apache.calcite.test.schemata.hr.HierarchySchema;
import org.apache.calcite.test.schemata.hr.HrSchema;
+import org.apache.calcite.test.schemata.hr.HrSchemaBig;
import org.junit.jupiter.api.Test;
+import java.util.Arrays;
import java.util.function.Consumer;
/**
@@ -218,18 +220,79 @@ class EnumerableJoinTest {
/** Test case for
* <a
href="https://issues.apache.org/jira/browse/CALCITE-3846">[CALCITE-3846]
* EnumerableMergeJoin: wrong comparison of composite key with null
values</a>. */
- @Test void testMergeJoinWithCompositeKeyAndNullValues() {
- tester(false, new HrSchema())
+ @Test void testMergeJoinInnerWithCompositeKeyAndNullValues() {
+ checkMergeJoinWithCompositeKeyAndNullValues(
+ false,
+ JoinRelType.INNER,
+ "empid=110; empid0=110",
+ "empid=100; empid0=100",
+ "empid=200; empid0=200");
+ checkMergeJoinWithCompositeKeyAndNullValues(
+ true,
+ JoinRelType.INNER,
+ "empid=48; empid0=48",
+ "empid=4; empid0=4",
+ "empid=4; empid0=8");
+ }
+
+ @Test void testMergeJoinLeftWithCompositeKeyAndNullValues() {
+ checkMergeJoinWithCompositeKeyAndNullValues(
+ false,
+ JoinRelType.LEFT,
+ "empid=110; empid0=110",
+ "empid=100; empid0=100",
+ "empid=150; empid0=null",
+ "empid=200; empid0=200");
+ checkMergeJoinWithCompositeKeyAndNullValues(
+ true,
+ JoinRelType.LEFT,
+ "empid=48; empid0=48",
+ "empid=47; empid0=null",
+ "empid=4; empid0=4");
+ }
+
+ @Test void testMergeJoinSemiWithCompositeKeyAndNullValues() {
+ checkMergeJoinWithCompositeKeyAndNullValues(
+ false,
+ JoinRelType.SEMI,
+ "empid=110",
+ "empid=100",
+ "empid=200");
+ checkMergeJoinWithCompositeKeyAndNullValues(
+ true,
+ JoinRelType.SEMI,
+ "empid=48",
+ "empid=4",
+ "empid=8");
+ }
+
+ @Test void testMergeJoinAntiWithCompositeKeyAndNullValues() {
+ checkMergeJoinWithCompositeKeyAndNullValues(
+ false,
+ JoinRelType.ANTI,
+ "empid=150");
+ checkMergeJoinWithCompositeKeyAndNullValues(
+ true,
+ JoinRelType.ANTI,
+ "empid=47",
+ "empid=3",
+ "empid=7");
+ }
+
+ private void checkMergeJoinWithCompositeKeyAndNullValues(boolean bigSchema,
JoinRelType joinType,
+ String... expected) {
+ CalciteAssert.AssertQuery checker =
+ tester(false, bigSchema ? new HrSchemaBig() : new HrSchema())
.withHook(Hook.PLANNER, (Consumer<RelOptPlanner>) planner -> {
planner.addRule(EnumerableRules.ENUMERABLE_MERGE_JOIN_RULE);
planner.removeRule(EnumerableRules.ENUMERABLE_JOIN_RULE);
})
.withRel(builder -> builder
- .scan("s", "emps")
- .sort(builder.field("deptno"), builder.field("commission"))
- .scan("s", "emps")
- .sort(builder.field("deptno"), builder.field("commission"))
- .join(JoinRelType.INNER,
+ .scan("s", "emps").as("e1")
+ .sort(builder.field("deptno"), builder.field("commission"),
builder.field("empid"))
+ .scan("s", "emps").as("e2")
+ .sort(builder.field("deptno"), builder.field("commission"),
builder.field("empid"))
+ .join(joinType,
builder.and(
builder.equals(
builder.field(2, 0, "deptno"),
@@ -237,22 +300,16 @@ class EnumerableJoinTest {
builder.equals(
builder.field(2, 0, "commission"),
builder.field(2, 1, "commission"))))
- .project(
- builder.field("empid"))
+ .project(joinType.projectsRight()
+ ? Arrays.asList(builder.field("e1", "empid"),
builder.field("e2", "empid"))
+ : Arrays.asList(builder.field("e1", "empid")))
.build())
- .explainHookMatches("" // It is important that we have MergeJoin in
the plan
- + "EnumerableCalc(expr#0..4=[{inputs}], empid=[$t0])\n"
- + " EnumerableMergeJoin(condition=[AND(=($1, $3), =($2, $4))],
joinType=[inner])\n"
- + " EnumerableSort(sort0=[$1], sort1=[$2], dir0=[ASC],
dir1=[ASC])\n"
- + " EnumerableCalc(expr#0..4=[{inputs}], proj#0..1=[{exprs}],
commission=[$t4])\n"
- + " EnumerableTableScan(table=[[s, emps]])\n"
- + " EnumerableSort(sort0=[$0], sort1=[$1], dir0=[ASC],
dir1=[ASC])\n"
- + " EnumerableCalc(expr#0..4=[{inputs}], deptno=[$t1],
commission=[$t4])\n"
- + " EnumerableTableScan(table=[[s, emps]])\n")
- .returnsUnordered("empid=100",
- "empid=110",
- "empid=150",
- "empid=200");
+ .explainHookContains("EnumerableMergeJoin"); // We must have MergeJoin
in the plan
+ if (bigSchema) {
+ checker.returnsStartingWith(expected);
+ } else {
+ checker.returnsOrdered(expected);
+ }
}
/** Test case for
diff --git
a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java
b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java
index d623d8619d..851c0d40d3 100644
--- a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java
+++ b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java
@@ -1924,7 +1924,7 @@ public abstract class EnumerableDefaults {
final Predicate1<TSource> predicate = v0 -> {
TKey key = outerKeySelector.apply(v0);
@SuppressWarnings("argument.type.incompatible")
- Enumerable<TInner> innersOfKey = innerLookup.get().get(key);
+ Enumerable<TInner> innersOfKey = key == null ? null :
innerLookup.get().get(key);
if (innersOfKey == null) {
return anti;
}
@@ -1963,9 +1963,11 @@ public abstract class EnumerableDefaults {
? inner.select(innerKeySelector).distinct()
: inner.select(innerKeySelector).distinct(comparer));
- final Predicate1<TSource> predicate = anti
- ? v0 -> !innerLookup.get().contains(outerKeySelector.apply(v0))
- : v0 -> innerLookup.get().contains(outerKeySelector.apply(v0));
+ final Predicate1<TSource> predicate = v0 -> {
+ TKey key = outerKeySelector.apply(v0);
+ boolean found = key != null && innerLookup.get().contains(key);
+ return anti ? !found : found;
+ };
return EnumerableDefaults.where(outer.enumerator(), predicate);
}
@@ -2146,9 +2148,9 @@ public abstract class EnumerableDefaults {
/**
* Joins two inputs that are sorted on the key.
- * Inputs must sorted in ascending order, nulls last.
+ * Inputs must be sorted in ascending order, nulls last.
*
- * @deprecated Use {@link #mergeJoin(Enumerable, Enumerable, Function1,
Function1, Function2, JoinType, Comparator)}
+ * @deprecated Use {@link #mergeJoin(Enumerable, Enumerable, Function1,
Function1, Predicate2, Function2, JoinType, Comparator, EqualityComparer)}
*/
@Deprecated // to be removed before 2.0
public static <TSource, TInner, TKey extends Comparable<TKey>, TResult>
Enumerable<TResult>
@@ -2168,7 +2170,7 @@ public abstract class EnumerableDefaults {
"not implemented, mergeJoin with generateNullsOnRight");
}
return mergeJoin(outer, inner, outerKeySelector, innerKeySelector, null,
resultSelector,
- JoinType.INNER, null);
+ JoinType.INNER, null, null);
}
/**
@@ -2191,7 +2193,7 @@ public abstract class EnumerableDefaults {
/**
* Joins two inputs that are sorted on the key.
- * Inputs must sorted in ascending order, nulls last.
+ * Inputs must be sorted in ascending order, nulls last.
*/
public static <TSource, TInner, TKey extends Comparable<TKey>, TResult>
Enumerable<TResult>
mergeJoin(final Enumerable<TSource> outer,
@@ -2202,7 +2204,7 @@ public abstract class EnumerableDefaults {
final JoinType joinType,
final Comparator<TKey> comparator) {
return mergeJoin(outer, inner, outerKeySelector, innerKeySelector, null,
resultSelector,
- joinType, comparator);
+ joinType, comparator, null);
}
/**
@@ -2219,6 +2221,8 @@ public abstract class EnumerableDefaults {
* types in the future (e.g. semi or anti joins).
* @param comparator key comparator, possibly null (in which case {@link
Comparable#compareTo}
* will be used).
+ * @param equalityComparer key equality comparer, possibly null (in which
case equals
+ * will be used), required to compare keys from the
same input.
*
* <p>NOTE: The current API is experimental and subject to change without
* notice.
@@ -2232,14 +2236,15 @@ public abstract class EnumerableDefaults {
final @Nullable Predicate2<TSource, TInner> extraPredicate,
final Function2<TSource, @Nullable TInner, TResult> resultSelector,
final JoinType joinType,
- final @Nullable Comparator<TKey> comparator) {
+ final @Nullable Comparator<TKey> comparator,
+ final @Nullable EqualityComparer<TKey> equalityComparer) {
if (!isMergeJoinSupported(joinType)) {
throw new UnsupportedOperationException("MergeJoin unsupported for join type " + joinType);
}
return new AbstractEnumerable<TResult>() {
@Override public Enumerator<TResult> enumerator() {
return new MergeJoinEnumerator<>(outer, inner, outerKeySelector, innerKeySelector,
- extraPredicate, resultSelector, joinType, comparator);
+ extraPredicate, resultSelector, joinType, comparator, equalityComparer);
}
};
}
@@ -4144,7 +4149,7 @@ public abstract class EnumerableDefaults {
}
/** Enumerator that performs a merge join on its sorted inputs.
- * Inputs must sorted in ascending order, nulls last.
+ * Inputs must be sorted in ascending order, nulls last.
*
* @param <TResult> result type
* @param <TSource> left input record type
@@ -4167,6 +4172,8 @@ public abstract class EnumerableDefaults {
private final JoinType joinType;
// key comparator, possibly null (Comparable#compareTo to be used in that case)
private final @Nullable Comparator<TKey> comparator;
+ // key equality comparer, possibly null (equals to be used in that case)
+ private final @Nullable EqualityComparer<TKey> equalityComparer;
private boolean done;
private @Nullable Enumerator<TResult> results = null;
// used for LEFT/ANTI join: if right input is over, all remaining elements from left are results
@@ -4181,7 +4188,8 @@ public abstract class EnumerableDefaults {
@Nullable Predicate2<TSource, TInner> extraPredicate,
Function2<TSource, @Nullable TInner, TResult> resultSelector,
JoinType joinType,
- @Nullable Comparator<TKey> comparator) {
+ @Nullable Comparator<TKey> comparator,
+ @Nullable EqualityComparer<TKey> equalityComparer) {
this.leftEnumerable = leftEnumerable;
this.rightEnumerable = rightEnumerable;
this.outerKeySelector = outerKeySelector;
@@ -4190,6 +4198,7 @@ public abstract class EnumerableDefaults {
this.resultSelector = resultSelector;
this.joinType = joinType;
this.comparator = comparator;
+ this.equalityComparer = equalityComparer;
start();
}
@@ -4258,15 +4267,18 @@ public abstract class EnumerableDefaults {
results = Linq4j.emptyEnumerator();
}
+ /** Method to compare keys from left and right input (nulls must not be considered equal). */
private int compare(TKey key1, TKey key2) {
- return comparator != null ? comparator.compare(key1, key2) : compareNullsLast(key1, key2);
+ return comparator != null
+ ? comparator.compare(key1, key2)
+ : compareNullsLastForMergeJoin(key1, key2);
}
- private int compareNullsLast(TKey v0, TKey v1) {
- return v0 == v1 ? 0
- : v0 == null ? 1
- : v1 == null ? -1
- : v0.compareTo(v1);
+ /** Method to compare keys from the same input (nulls must be considered equal). */
+ private boolean compareEquals(TKey key1, TKey key2) {
+ return equalityComparer != null
+ ? equalityComparer.equal(key1, key2)
+ : Objects.equals(key1, key2);
}
/** Moves to the next key that is present in both sides. Populates
@@ -4291,7 +4303,14 @@ public abstract class EnumerableDefaults {
done = true;
return false;
}
- int c = compare(leftKey, rightKey);
+ int c;
+ try {
+ c = compare(leftKey, rightKey);
+ } catch (BothValuesAreNullException e) {
+ // consider the first (left) value as "bigger", to advance on the right value
+ // and continue with the algorithm
+ c = 1;
+ }
if (c == 0) {
break;
}
@@ -4390,13 +4409,7 @@ public abstract class EnumerableDefaults {
// if we reach a null key, we are done (except LEFT join, that needs to process LHS fully)
break;
}
- int c = compare(leftKey, leftKey2);
- if (c != 0) {
- if (c > 0) {
- throw new IllegalStateException(
- "mergeJoin assumes inputs sorted in ascending order, " + "however '"
- + leftKey + "' is greater than '" + leftKey2 + "'");
- }
+ if (!compareEquals(leftKey, leftKey2)) {
return true;
}
lefts.add(left);
@@ -4423,13 +4436,7 @@ public abstract class EnumerableDefaults {
// if we reach a null key, we are done
break;
}
- int c = compare(rightKey, rightKey2);
- if (c != 0) {
- if (c > 0) {
- throw new IllegalStateException(
- "mergeJoin assumes input sorted in ascending order, " + "however '"
- + rightKey + "' is greater than '" + rightKey2 + "'");
- }
+ if (!compareEquals(rightKey, rightKey2)) {
return true;
}
rights.add(right);
@@ -4495,6 +4502,36 @@ public abstract class EnumerableDefaults {
}
}
+ /**
+ * Exception used for control flow (it does not populate the stack trace to be more efficient)
+ * to signal that both values are <code>null</code>, so that the caller method (i.e. MergeJoin
+ * algorithm) can act accordingly.
+ */
+ private static class BothValuesAreNullException extends RuntimeException {
+ @Override public synchronized Throwable fillInStackTrace() {
+ return this;
+ }
+ }
+
+ public static int compareNullsLastForMergeJoin(@Nullable Comparable v0, @Nullable Comparable v1) {
+ return compareNullsLastForMergeJoin(v0, v1, null);
+ }
+
+ public static int compareNullsLastForMergeJoin(@Nullable Comparable v0, @Nullable Comparable v1,
+ @Nullable Comparator comparator) {
+ // Special code for mergeJoin algorithm: in case of two null values, they must not be
+ // considered as equal (otherwise the join would return incorrect results)
+ if (v0 == null && v1 == null) {
+ throw new BothValuesAreNullException();
+ }
+
+ //noinspection unchecked
+ return v0 == v1 ? 0
+ : v0 == null ? 1
+ : v1 == null ? -1
+ : comparator == null ? v0.compareTo(v1) : comparator.compare(v0, v1);
+ }
+
/** Enumerates the elements of a cartesian product of two inputs.
*
* @param <TResult> result type