This is an automated email from the ASF dual-hosted git repository. rubenql pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/main by this push: new 3aee0b86aa [CALCITE-5732] EnumerableHashJoin and EnumerableMergeJoin on composite key return rows matching condition 'null = null' 3aee0b86aa is described below commit 3aee0b86aa23476cbdecc75ad5d43b936a6fff7b Author: rubenada <rube...@gmail.com> AuthorDate: Wed Jul 12 14:45:05 2023 +0100 [CALCITE-5732] EnumerableHashJoin and EnumerableMergeJoin on composite key return rows matching condition 'null = null' --- .../adapter/enumerable/EnumerableHashJoin.java | 8 +- .../adapter/enumerable/EnumerableMergeJoin.java | 7 +- .../calcite/adapter/enumerable/PhysType.java | 26 +++- .../calcite/adapter/enumerable/PhysTypeImpl.java | 168 +++++++++++++-------- .../java/org/apache/calcite/runtime/Utilities.java | 11 ++ .../org/apache/calcite/util/BuiltInMethod.java | 2 +- .../apache/calcite/runtime/EnumerablesTest.java | 28 ++-- .../test/enumerable/EnumerableHashJoinTest.java | 136 ++++++++++++++--- .../test/enumerable/EnumerableJoinTest.java | 101 ++++++++++--- .../apache/calcite/linq4j/EnumerableDefaults.java | 105 ++++++++----- 10 files changed, 430 insertions(+), 162 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableHashJoin.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableHashJoin.java index 259fb68503..cfe44da75d 100644 --- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableHashJoin.java +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableHashJoin.java @@ -218,8 +218,8 @@ public class EnumerableHashJoin extends Join implements EnumerableRel { Expressions.list( leftExpression, rightExpression, - leftResult.physType.generateAccessor(joinInfo.leftKeys), - rightResult.physType.generateAccessor(joinInfo.rightKeys), + leftResult.physType.generateAccessorWithoutNulls(joinInfo.leftKeys), + rightResult.physType.generateAccessorWithoutNulls(joinInfo.rightKeys), Util.first(keyPhysType.comparer(), Expressions.constant(null)), predicate))) @@ -264,8 +264,8 @@ public class EnumerableHashJoin extends Join implements EnumerableRel { BuiltInMethod.HASH_JOIN.method, Expressions.list( rightExpression, - leftResult.physType.generateAccessor(joinInfo.leftKeys), - rightResult.physType.generateAccessor(joinInfo.rightKeys), + leftResult.physType.generateAccessorWithoutNulls(joinInfo.leftKeys), + rightResult.physType.generateAccessorWithoutNulls(joinInfo.rightKeys), EnumUtils.joinSelector(joinType, physType, ImmutableList.of( diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeJoin.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeJoin.java index d3f45d8d31..3f42e065c8 100644 --- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeJoin.java +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeJoin.java @@ -492,7 +492,7 @@ public class EnumerableMergeJoin extends Join implements EnumerableRel { RelFieldCollation.NullDirection.LAST)); } final RelCollation collation = RelCollations.of(fieldCollations); - final Expression comparator = leftKeyPhysType.generateComparator(collation); + final Expression comparator = leftKeyPhysType.generateMergeJoinComparator(collation); return implementor.result( physType, @@ -512,6 +512,9 @@ public class EnumerableMergeJoin extends Join implements EnumerableRel { ImmutableList.of( leftResult.physType, rightResult.physType)), Expressions.constant(EnumUtils.toLinq4jJoinType(joinType)), - comparator))).toBlock()); + comparator, + Util.first( + leftKeyPhysType.comparer(), + Expressions.constant(null))))).toBlock()); } } diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/PhysType.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/PhysType.java index 4447e91dfe..8d4eeb176c 100644 --- a/core/src/main/java/org/apache/calcite/adapter/enumerable/PhysType.java +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/PhysType.java @@ -112,11 +112,28 @@ public interface PhysType { * public Object[] apply(Employee v1) { * return FlatLists.of(v1.<fieldN>, v1.<fieldM>); * } - * } * }</pre></blockquote> */ Expression generateAccessor(List<Integer> fields); + /** Similar to {@link #generateAccessor(List)}, but if one of the fields is <code>null</code>, + * it will return <code>null</code>. + * + * <p>For example: + * + * <blockquote><pre> + * new Function1<Employee, Object[]> { + * public Object[] apply(Employee v1) { + * return v1.<fieldN> == null + * ? null + * : v1.<fieldM> == null + * ? null + * : FlatLists.of(v1.<fieldN>, v1.<fieldM>); + * } + * }</pre></blockquote> + */ + Expression generateAccessorWithoutNulls(List<Integer> fields); + /** Generates a selector for the given fields from an expression, with the * default row format. */ Expression generateSelector( @@ -181,6 +198,13 @@ public interface PhysType { Expression generateComparator( RelCollation collation); + /** Similar to {@link #generateComparator(RelCollation)}, but with some specificities for + * MergeJoin algorithm: it will not consider two <code>null</code> values as equal. + * + * @see org.apache.calcite.linq4j.EnumerableDefaults#compareNullsLastForMergeJoin + */ + Expression generateMergeJoinComparator(RelCollation collation); + /** Returns a expression that yields a comparer, or null if this type * is comparable. */ @Nullable Expression comparer(); diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/PhysTypeImpl.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/PhysTypeImpl.java index 55a44b2a29..ab51dd3e35 100644 --- a/core/src/main/java/org/apache/calcite/adapter/enumerable/PhysTypeImpl.java +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/PhysTypeImpl.java @@ -407,6 +407,24 @@ public class PhysTypeImpl implements PhysType { } @Override public Expression generateComparator(RelCollation collation) { + return this.generateComparator(collation, fieldCollation -> { + final int index = fieldCollation.getFieldIndex(); + final boolean nullsFirst = + fieldCollation.nullDirection + == RelFieldCollation.NullDirection.FIRST; + final boolean descending = + fieldCollation.getDirection() + == RelFieldCollation.Direction.DESCENDING; + return fieldNullable(index) + ? (nullsFirst != descending + ? "compareNullsFirst" + : "compareNullsLast") + : "compare"; + }); + } + + private Expression generateComparator(RelCollation collation, + Function1<RelFieldCollation, String> compareMethodNameFunction) { // int c; // c = Utilities.compare(v0, v1); // if (c != 0) return c; // or -c if descending @@ -437,9 +455,6 @@ public class PhysTypeImpl implements PhysType { default: break; } - final boolean nullsFirst = - fieldCollation.nullDirection - == RelFieldCollation.NullDirection.FIRST; final boolean descending = fieldCollation.getDirection() == RelFieldCollation.Direction.DESCENDING; @@ -449,11 +464,7 @@ public class PhysTypeImpl implements PhysType { parameterC, Expressions.call( Utilities.class, - fieldNullable(index) - ? (nullsFirst != descending - ? "compareNullsFirst" - : "compareNullsLast") - : "compare", + compareMethodNameFunction.apply(fieldCollation), Expressions.list( arg0, arg1) @@ -511,6 +522,17 @@ public class PhysTypeImpl implements PhysType { memberDeclarations); } + @Override public Expression generateMergeJoinComparator(RelCollation collation) { + return this.generateComparator(collation, fieldCollation -> { + // merge join keys must be sorted in ascending order, nulls last + assert fieldCollation.nullDirection == RelFieldCollation.NullDirection.LAST; + assert fieldCollation.getDirection() == RelFieldCollation.Direction.ASCENDING; + return fieldNullable(fieldCollation.getFieldIndex()) + ? "compareNullsLastForMergeJoin" + : "compare"; + }); + } + @Override public RelDataType getRowType() { return rowType; } @@ -616,65 +638,79 @@ public class PhysTypeImpl implements PhysType { for (int field : fields) { list.add(fieldReference(v1, field)); } - switch (list.size()) { - case 2: - return Expressions.lambda( - Function1.class, - Expressions.call( - List.class, - null, - BuiltInMethod.LIST2.method, - list), - v1); - case 3: - return Expressions.lambda( - Function1.class, - Expressions.call( - List.class, - null, - BuiltInMethod.LIST3.method, - list), - v1); - case 4: - return Expressions.lambda( - Function1.class, - Expressions.call( - List.class, - null, - BuiltInMethod.LIST4.method, - list), - v1); - case 5: - return Expressions.lambda( - Function1.class, - Expressions.call( - List.class, - null, - BuiltInMethod.LIST5.method, - list), - v1); - case 6: - return Expressions.lambda( - Function1.class, - Expressions.call( - List.class, - null, - BuiltInMethod.LIST6.method, - list), - v1); - default: - return Expressions.lambda( - Function1.class, - Expressions.call( - List.class, - null, - BuiltInMethod.LIST_N.method, - Expressions.newArrayInit( - Comparable.class, - list)), - v1); - } + return Expressions.lambda(Function1.class, getListExpression(list), v1); + } + } + + private static Expression getListExpression(Expressions.FluentList<Expression> list) { + assert list.size() >= 2; + + switch (list.size()) { + case 2: + return Expressions.call( + List.class, + null, + BuiltInMethod.LIST2.method, + list); + case 3: + return Expressions.call( + List.class, + null, + BuiltInMethod.LIST3.method, + list); + case 4: + return Expressions.call( + List.class, + null, + BuiltInMethod.LIST4.method, + list); + case 5: + return Expressions.call( + List.class, + null, + BuiltInMethod.LIST5.method, + list); + case 6: + return Expressions.call( + List.class, + null, + BuiltInMethod.LIST6.method, + list); + default: + return Expressions.call( + List.class, + null, + BuiltInMethod.LIST_N.method, + Expressions.newArrayInit(Comparable.class, list)); + } + } + + @Override public Expression generateAccessorWithoutNulls(List<Integer> fields) { + if (fields.size() < 2) { + return generateAccessor(fields); + } + + ParameterExpression v1 = Expressions.parameter(javaRowClass, "v1"); + Expressions.FluentList<Expression> list = Expressions.list(); + for (int field : fields) { + list.add(fieldReference(v1, field)); + } + + // (v1.<field0> == null) + // ? null + // : (v1.<field1> == null) + // ? null; + // : ... + // : FlatLists.of(...); + Expression exp = getListExpression(list); + for (int i = list.size() - 1; i >= 0; i--) { + exp = + Expressions.condition( + Expressions.equal(list.get(i), Expressions.constant(null)), + Expressions.constant(null), + exp); } + return Expressions.lambda(Function1.class, exp, v1); } @Override public Expression fieldReference( diff --git a/core/src/main/java/org/apache/calcite/runtime/Utilities.java b/core/src/main/java/org/apache/calcite/runtime/Utilities.java index 458a07c898..3409af7224 100644 --- a/core/src/main/java/org/apache/calcite/runtime/Utilities.java +++ b/core/src/main/java/org/apache/calcite/runtime/Utilities.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.runtime; +import org.apache.calcite.linq4j.EnumerableDefaults; + import org.checkerframework.checker.nullness.qual.Nullable; import java.text.Collator; @@ -250,6 +252,15 @@ public class Utilities { : FlatLists.ComparableListImpl.compare(v0, v1); } + public static int compareNullsLastForMergeJoin(@Nullable Comparable v0, @Nullable Comparable v1) { + return EnumerableDefaults.compareNullsLastForMergeJoin(v0, v1); + } + + public static int compareNullsLastForMergeJoin(@Nullable Comparable v0, @Nullable Comparable v1, + Comparator comparator) { + return EnumerableDefaults.compareNullsLastForMergeJoin(v0, v1, comparator); + } + /** Creates a pattern builder. */ public static Pattern.PatternBuilder patternBuilder() { return Pattern.builder(); diff --git a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java index 6d7eedd10f..19c5228e09 100644 --- a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java +++ b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java @@ -205,7 +205,7 @@ public enum BuiltInMethod { List.class, int.class, Consumer.class), MERGE_JOIN(EnumerableDefaults.class, "mergeJoin", Enumerable.class, Enumerable.class, Function1.class, Function1.class, Predicate2.class, Function2.class, - JoinType.class, Comparator.class), + JoinType.class, Comparator.class, EqualityComparer.class), SLICE0(Enumerables.class, "slice0", Enumerable.class), SEMI_JOIN(EnumerableDefaults.class, "semiJoin", Enumerable.class, Enumerable.class, Function1.class, Function1.class, diff --git a/core/src/test/java/org/apache/calcite/runtime/EnumerablesTest.java b/core/src/test/java/org/apache/calcite/runtime/EnumerablesTest.java index 13970b53a7..0010482f95 100644 --- a/core/src/test/java/org/apache/calcite/runtime/EnumerablesTest.java +++ b/core/src/test/java/org/apache/calcite/runtime/EnumerablesTest.java @@ -414,7 +414,7 @@ class EnumerablesTest { e1 -> e1.name, e2 -> e2.name, (e1, e2) -> e1.deptno < e2.deptno, - (v0, v1) -> v0 + "-" + v1, JoinType.INNER, null).toList(), + (v0, v1) -> v0 + "-" + v1, JoinType.INNER, null, null).toList(), hasToString("[" + "Emp(1, Fred)-Emp(2, Fred), " + "Emp(1, Fred)-Emp(3, Fred), " @@ -430,7 +430,7 @@ class EnumerablesTest { e2 -> e2.name, e1 -> e1.name, (e2, e1) -> e2.deptno > e1.deptno, - (v0, v1) -> v0 + "-" + v1, JoinType.INNER, null).toList(), + (v0, v1) -> v0 + "-" + v1, JoinType.INNER, null, null).toList(), hasToString("[" + "Emp(2, Fred)-Emp(1, Fred), " + "Emp(3, Fred)-Emp(1, Fred), " @@ -446,7 +446,7 @@ class EnumerablesTest { e1 -> e1.name, e2 -> e2.name, (e1, e2) -> e1.deptno == e2.deptno * 2, - (v0, v1) -> v0 + "-" + v1, JoinType.INNER, null).toList(), + (v0, v1) -> v0 + "-" + v1, JoinType.INNER, null, null).toList(), hasToString("[]")); assertThat( @@ -456,7 +456,7 @@ class EnumerablesTest { e2 -> e2.name, e1 -> e1.name, (e2, e1) -> e2.deptno == e1.deptno * 2, - (v0, v1) -> v0 + "-" + v1, JoinType.INNER, null).toList(), + (v0, v1) -> v0 + "-" + v1, JoinType.INNER, null, null).toList(), hasToString("[Emp(2, Fred)-Emp(1, Fred)]")); assertThat( @@ -466,7 +466,7 @@ class EnumerablesTest { e2 -> e2.name, e1 -> e1.name, (e2, e1) -> e2.deptno == e1.deptno + 2, - (v0, v1) -> v0 + "-" + v1, JoinType.INNER, null).toList(), + (v0, v1) -> v0 + "-" + v1, JoinType.INNER, null, null).toList(), hasToString("[Emp(3, Fred)-Emp(1, Fred), Emp(5, Joe)-Emp(3, Joe)]")); } @@ -493,7 +493,7 @@ class EnumerablesTest { null, (v0, v1) -> v0, JoinType.SEMI, - null).toList(), + null, null).toList(), hasToString("[Dept(10, Marketing)," + " Dept(20, Sales)," + " Dept(30, Research)]")); @@ -522,7 +522,7 @@ class EnumerablesTest { (d, e) -> e.name.contains("a"), (v0, v1) -> v0, JoinType.SEMI, - null).toList(), + null, null).toList(), hasToString("[Dept(20, Sales)]")); } @@ -550,7 +550,7 @@ class EnumerablesTest { (e, d) -> e.name.startsWith("T"), (v0, v1) -> v0, JoinType.SEMI, - null).toList(), + null, null).toList(), hasToString("[Emp(30, Theodore)]")); } @@ -578,7 +578,7 @@ class EnumerablesTest { null, (v0, v1) -> v0, JoinType.ANTI, - null).toList(), + null, null).toList(), hasToString("[Dept(25, HR), Dept(40, Development)]")); } @@ -605,7 +605,7 @@ class EnumerablesTest { (d, e) -> e.name.startsWith("F") || e.name.startsWith("S"), (v0, v1) -> v0, JoinType.ANTI, - null).toList(), + null, null).toList(), hasToString("[Dept(25, HR), Dept(30, Research), " + "Dept(40, Development)]")); } @@ -634,7 +634,7 @@ class EnumerablesTest { (e, d) -> d.deptno < 30, (v0, v1) -> v0, JoinType.ANTI, - null).toList(), + null, null).toList(), hasToString("[Emp(30, Fred), Emp(20, Sebastian), Emp(20, Zoey)]")); } @@ -661,7 +661,7 @@ class EnumerablesTest { null, (v0, v1) -> v0 + "-" + v1, JoinType.LEFT, - null).toList(), + null, null).toList(), hasToString("[Dept(10, Marketing)-Emp(10, Fred)," + " Dept(20, Sales)-Emp(20, Theodore)," + " Dept(20, Sales)-Emp(20, Sebastian)," @@ -694,7 +694,7 @@ class EnumerablesTest { (d, e) -> e.name.contains("a"), (v0, v1) -> v0 + "-" + v1, JoinType.LEFT, - null).toList(), + null, null).toList(), hasToString("[Dept(10, Marketing)-null," + " Dept(20, Sales)-Emp(20, Sebastian)," + " Dept(25, HR)-null," @@ -726,7 +726,7 @@ class EnumerablesTest { (e, d) -> e.name.startsWith("T"), (v0, v1) -> v0 + "-" + v1, JoinType.LEFT, - null).toList(), + null, null).toList(), hasToString("[Emp(30, Fred)-null," + " Emp(20, Sebastian)-null," + " Emp(30, Theodore)-Dept(30, Theodore)," diff --git a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableHashJoinTest.java b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableHashJoinTest.java index 1e8be1b01c..1207e405c2 100644 --- a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableHashJoinTest.java +++ b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableHashJoinTest.java @@ -56,24 +56,6 @@ class EnumerableHashJoinTest { "empid=150; name=Sebastian; dept=Sales"); } - @Test void innerJoinWithPredicate() { - tester(false, new HrSchema()) - .query( - "select e.empid, e.name, d.name as dept from emps e join depts d" - + " on e.deptno=d.deptno and e.empid<150 and e.empid>d.deptno") - .explainContains("EnumerableCalc(expr#0..4=[{inputs}], empid=[$t0], name=[$t2], " - + "dept=[$t4])\n" - + " EnumerableHashJoin(condition=[AND(=($1, $3), >($0, $3))], joinType=[inner])\n" - + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=[150], expr#6=[<($t0, $t5)], " - + "proj#0..2=[{exprs}], $condition=[$t6])\n" - + " EnumerableTableScan(table=[[s, emps]])\n" - + " EnumerableCalc(expr#0..3=[{inputs}], proj#0..1=[{exprs}])\n" - + " EnumerableTableScan(table=[[s, depts]])\n") - .returnsUnordered( - "empid=100; name=Bill; dept=Sales", - "empid=110; name=Theodore; dept=Sales"); - } - @Test void leftOuterJoin() { tester(false, new HrSchema()) .query( @@ -201,6 +183,124 @@ class EnumerableHashJoinTest { "name=Sebastian; salary=7000.0"); } + @Test void innerJoinWithPredicate() { + tester(false, new HrSchema()) + .query( + "select e.empid, e.name, d.name as dept from emps e join depts d" + + " on e.deptno=d.deptno and e.empid<150 and e.empid>d.deptno") + .explainContains("EnumerableCalc(expr#0..4=[{inputs}], empid=[$t0], name=[$t2], " + + "dept=[$t4])\n" + + " EnumerableHashJoin(condition=[AND(=($1, $3), >($0, $3))], joinType=[inner])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=[150], expr#6=[<($t0, $t5)], " + + "proj#0..2=[{exprs}], $condition=[$t6])\n" + + " EnumerableTableScan(table=[[s, emps]])\n" + + " EnumerableCalc(expr#0..3=[{inputs}], proj#0..1=[{exprs}])\n" + + " EnumerableTableScan(table=[[s, depts]])\n") + .returnsUnordered( + "empid=100; name=Bill; dept=Sales", + "empid=110; name=Theodore; dept=Sales"); + } + + @Test void innerJoinWithCompositeKeyAndNullValues() { + tester(false, new HrSchema()) + .query( + "select e1.empid from emps e1 join emps e2 " + + "on e1.deptno=e2.deptno and e1.commission=e2.commission") + .withHook(Hook.PLANNER, (Consumer<RelOptPlanner>) planner -> + planner.removeRule(EnumerableRules.ENUMERABLE_MERGE_JOIN_RULE)) + .explainContains("EnumerableCalc(expr#0..4=[{inputs}], empid=[$t0])\n" + + " EnumerableHashJoin(condition=[AND(=($1, $3), =($2, $4))], joinType=[inner])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], proj#0..1=[{exprs}], commission=[$t4])\n" + + " EnumerableTableScan(table=[[s, emps]])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], deptno=[$t1], commission=[$t4])\n" + + " EnumerableTableScan(table=[[s, emps]])\n") + .returnsUnordered( + "empid=100", + "empid=110", + "empid=200"); + } + + @Test void leftOuterJoinWithCompositeKeyAndNullValues() { + tester(false, new HrSchema()) + .query( + "select e1.empid, e2.empid from emps e1 left outer join emps e2 " + + "on e1.deptno=e2.deptno and e1.commission=e2.commission") + .withHook(Hook.PLANNER, (Consumer<RelOptPlanner>) planner -> + planner.removeRule(EnumerableRules.ENUMERABLE_MERGE_JOIN_RULE)) + .explainContains("EnumerableCalc(expr#0..5=[{inputs}], empid=[$t0], empid0=[$t3])\n" + + " EnumerableHashJoin(condition=[AND(=($1, $4), =($2, $5))], joinType=[left])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], proj#0..1=[{exprs}], commission=[$t4])\n" + + " EnumerableTableScan(table=[[s, emps]])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], proj#0..1=[{exprs}], commission=[$t4])\n" + + " EnumerableTableScan(table=[[s, emps]])\n") + .returnsUnordered( + "empid=100; empid=100", + "empid=110; empid=110", + "empid=150; empid=null", + "empid=200; empid=200"); + } + + @Test void rightOuterJoinWithCompositeKeyAndNullValues() { + tester(false, new HrSchema()) + .query( + "select e1.empid, e2.empid from emps e1 right outer join emps e2 " + + "on e1.deptno=e2.deptno and e1.commission=e2.commission") + .withHook(Hook.PLANNER, (Consumer<RelOptPlanner>) planner -> + planner.removeRule(EnumerableRules.ENUMERABLE_MERGE_JOIN_RULE)) + .explainContains("EnumerableCalc(expr#0..5=[{inputs}], empid=[$t0], empid0=[$t3])\n" + + " EnumerableHashJoin(condition=[AND(=($1, $4), =($2, $5))], joinType=[right])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], proj#0..1=[{exprs}], commission=[$t4])\n" + + " EnumerableTableScan(table=[[s, emps]])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], proj#0..1=[{exprs}], commission=[$t4])\n" + + " EnumerableTableScan(table=[[s, emps]])\n") + .returnsUnordered( + "empid=100; empid=100", + "empid=110; empid=110", + "empid=200; empid=200", + "empid=null; empid=150"); + } + + @Test void fullOuterJoinWithCompositeKeyAndNullValues() { + tester(false, new HrSchema()) + .query( + "select e1.empid, e2.empid from emps e1 full outer join emps e2 " + + "on e1.deptno=e2.deptno and e1.commission=e2.commission") + .withHook(Hook.PLANNER, (Consumer<RelOptPlanner>) planner -> + planner.removeRule(EnumerableRules.ENUMERABLE_MERGE_JOIN_RULE)) + .explainContains("EnumerableCalc(expr#0..5=[{inputs}], empid=[$t0], empid0=[$t3])\n" + + " EnumerableHashJoin(condition=[AND(=($1, $4), =($2, $5))], joinType=[full])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], proj#0..1=[{exprs}], commission=[$t4])\n" + + " EnumerableTableScan(table=[[s, emps]])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], proj#0..1=[{exprs}], commission=[$t4])\n" + + " EnumerableTableScan(table=[[s, emps]])\n") + .returnsUnordered( + "empid=100; empid=100", + "empid=110; empid=110", + "empid=150; empid=null", + "empid=200; empid=200", + "empid=null; empid=150"); + } + + @Test void semiJoinWithCompositeKeyAndNullValues() { + tester(true, new HrSchema()) + .query( + "select e1.empid from emps e1 where exists (select 1 from emps e2 " + + "where e1.deptno=e2.deptno and e1.commission=e2.commission)") + .withHook(Hook.PLANNER, (Consumer<RelOptPlanner>) planner -> { + planner.removeRule(EnumerableRules.ENUMERABLE_MERGE_JOIN_RULE); + }) + .explainContains("EnumerableCalc(expr#0..2=[{inputs}], empid=[$t0])\n" + + " EnumerableHashJoin(condition=[AND(=($1, $4), =($2, $7))], joinType=[semi])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], proj#0..1=[{exprs}], commission=[$t4])\n" + + " EnumerableTableScan(table=[[s, emps]])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=[IS NOT NULL($t4)], proj#0..4=[{exprs}], $condition=[$t5])\n" + + " EnumerableTableScan(table=[[s, emps]])\n") + .returnsUnordered( + "empid=100", + "empid=110", + "empid=200"); + } + private CalciteAssert.AssertThat tester(boolean forceDecorrelate, Object schema) { return CalciteAssert.that() diff --git a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableJoinTest.java b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableJoinTest.java index 30f69f1034..a15af886d6 100644 --- a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableJoinTest.java +++ b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableJoinTest.java @@ -29,9 +29,11 @@ import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.test.CalciteAssert; import org.apache.calcite.test.schemata.hr.HierarchySchema; import org.apache.calcite.test.schemata.hr.HrSchema; +import org.apache.calcite.test.schemata.hr.HrSchemaBig; import org.junit.jupiter.api.Test; +import java.util.Arrays; import java.util.function.Consumer; /** @@ -218,18 +220,79 @@ class EnumerableJoinTest { /** Test case for * <a href="https://issues.apache.org/jira/browse/CALCITE-3846">[CALCITE-3846] * EnumerableMergeJoin: wrong comparison of composite key with null values</a>. */ - @Test void testMergeJoinWithCompositeKeyAndNullValues() { - tester(false, new HrSchema()) + @Test void testMergeJoinInnerWithCompositeKeyAndNullValues() { + checkMergeJoinWithCompositeKeyAndNullValues( + false, + JoinRelType.INNER, + "empid=110; empid0=110", + "empid=100; empid0=100", + "empid=200; empid0=200"); + checkMergeJoinWithCompositeKeyAndNullValues( + true, + JoinRelType.INNER, + "empid=48; empid0=48", + "empid=4; empid0=4", + "empid=4; empid0=8"); + } + + @Test void testMergeJoinLeftWithCompositeKeyAndNullValues() { + checkMergeJoinWithCompositeKeyAndNullValues( + false, + JoinRelType.LEFT, + "empid=110; empid0=110", + "empid=100; empid0=100", + "empid=150; empid0=null", + "empid=200; empid0=200"); + checkMergeJoinWithCompositeKeyAndNullValues( + true, + JoinRelType.LEFT, + "empid=48; empid0=48", + "empid=47; empid0=null", + "empid=4; empid0=4"); + } + + @Test void testMergeJoinSemiWithCompositeKeyAndNullValues() { + checkMergeJoinWithCompositeKeyAndNullValues( + false, + JoinRelType.SEMI, + "empid=110", + "empid=100", + "empid=200"); + checkMergeJoinWithCompositeKeyAndNullValues( + true, + JoinRelType.SEMI, + "empid=48", + "empid=4", + "empid=8"); + } + + @Test void testMergeJoinAntiWithCompositeKeyAndNullValues() { + checkMergeJoinWithCompositeKeyAndNullValues( + false, + JoinRelType.ANTI, + "empid=150"); + checkMergeJoinWithCompositeKeyAndNullValues( + true, + JoinRelType.ANTI, + "empid=47", + "empid=3", + "empid=7"); + } + + private void checkMergeJoinWithCompositeKeyAndNullValues(boolean bigSchema, JoinRelType joinType, + String... expected) { + CalciteAssert.AssertQuery checker = + tester(false, bigSchema ? new HrSchemaBig() : new HrSchema()) .withHook(Hook.PLANNER, (Consumer<RelOptPlanner>) planner -> { planner.addRule(EnumerableRules.ENUMERABLE_MERGE_JOIN_RULE); planner.removeRule(EnumerableRules.ENUMERABLE_JOIN_RULE); }) .withRel(builder -> builder - .scan("s", "emps") - .sort(builder.field("deptno"), builder.field("commission")) - .scan("s", "emps") - .sort(builder.field("deptno"), builder.field("commission")) - .join(JoinRelType.INNER, + .scan("s", "emps").as("e1") + .sort(builder.field("deptno"), builder.field("commission"), builder.field("empid")) + .scan("s", "emps").as("e2") + .sort(builder.field("deptno"), builder.field("commission"), builder.field("empid")) + .join(joinType, builder.and( builder.equals( builder.field(2, 0, "deptno"), @@ -237,22 +300,16 @@ class EnumerableJoinTest { builder.equals( builder.field(2, 0, "commission"), builder.field(2, 1, "commission")))) - .project( - builder.field("empid")) + .project(joinType.projectsRight() + ? Arrays.asList(builder.field("e1", "empid"), builder.field("e2", "empid")) + : Arrays.asList(builder.field("e1", "empid"))) .build()) - .explainHookMatches("" // It is important that we have MergeJoin in the plan - + "EnumerableCalc(expr#0..4=[{inputs}], empid=[$t0])\n" - + " EnumerableMergeJoin(condition=[AND(=($1, $3), =($2, $4))], joinType=[inner])\n" - + " EnumerableSort(sort0=[$1], sort1=[$2], dir0=[ASC], dir1=[ASC])\n" - + " EnumerableCalc(expr#0..4=[{inputs}], proj#0..1=[{exprs}], commission=[$t4])\n" - + " EnumerableTableScan(table=[[s, emps]])\n" - + " EnumerableSort(sort0=[$0], sort1=[$1], dir0=[ASC], dir1=[ASC])\n" - + " EnumerableCalc(expr#0..4=[{inputs}], deptno=[$t1], commission=[$t4])\n" - + " EnumerableTableScan(table=[[s, emps]])\n") - .returnsUnordered("empid=100", - "empid=110", - "empid=150", - "empid=200"); + .explainHookContains("EnumerableMergeJoin"); // We must have MergeJoin in the plan + if (bigSchema) { + checker.returnsStartingWith(expected); + } else { + checker.returnsOrdered(expected); + } } /** Test case for diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java index d623d8619d..851c0d40d3 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java @@ -1924,7 +1924,7 @@ public abstract class EnumerableDefaults { final Predicate1<TSource> predicate = v0 -> { TKey key = outerKeySelector.apply(v0); @SuppressWarnings("argument.type.incompatible") - Enumerable<TInner> innersOfKey = innerLookup.get().get(key); + Enumerable<TInner> innersOfKey = key == null ? null : innerLookup.get().get(key); if (innersOfKey == null) { return anti; } @@ -1963,9 +1963,11 @@ public abstract class EnumerableDefaults { ? inner.select(innerKeySelector).distinct() : inner.select(innerKeySelector).distinct(comparer)); - final Predicate1<TSource> predicate = anti - ? v0 -> !innerLookup.get().contains(outerKeySelector.apply(v0)) - : v0 -> innerLookup.get().contains(outerKeySelector.apply(v0)); + final Predicate1<TSource> predicate = v0 -> { + TKey key = outerKeySelector.apply(v0); + boolean found = key != null && innerLookup.get().contains(key); + return anti ? !found : found; + }; return EnumerableDefaults.where(outer.enumerator(), predicate); } @@ -2146,9 +2148,9 @@ public abstract class EnumerableDefaults { /** * Joins two inputs that are sorted on the key. - * Inputs must sorted in ascending order, nulls last. + * Inputs must be sorted in ascending order, nulls last. * - * @deprecated Use {@link #mergeJoin(Enumerable, Enumerable, Function1, Function1, Function2, JoinType, Comparator)} + * @deprecated Use {@link #mergeJoin(Enumerable, Enumerable, Function1, Function1, Predicate2, Function2, JoinType, Comparator, EqualityComparer)} */ @Deprecated // to be removed before 2.0 public static <TSource, TInner, TKey extends Comparable<TKey>, TResult> Enumerable<TResult> @@ -2168,7 +2170,7 @@ public abstract class EnumerableDefaults { "not implemented, mergeJoin with generateNullsOnRight"); } return mergeJoin(outer, inner, outerKeySelector, innerKeySelector, null, resultSelector, - JoinType.INNER, null); + JoinType.INNER, null, null); } /** @@ -2191,7 +2193,7 @@ public abstract class EnumerableDefaults { /** * Joins two inputs that are sorted on the key. - * Inputs must sorted in ascending order, nulls last. + * Inputs must be sorted in ascending order, nulls last. */ public static <TSource, TInner, TKey extends Comparable<TKey>, TResult> Enumerable<TResult> mergeJoin(final Enumerable<TSource> outer, @@ -2202,7 +2204,7 @@ public abstract class EnumerableDefaults { final JoinType joinType, final Comparator<TKey> comparator) { return mergeJoin(outer, inner, outerKeySelector, innerKeySelector, null, resultSelector, - joinType, comparator); + joinType, comparator, null); } /** @@ -2219,6 +2221,8 @@ public abstract class EnumerableDefaults { * types in the future (e.g. semi or anti joins). * @param comparator key comparator, possibly null (in which case {@link Comparable#compareTo} * will be used). + * @param equalityComparer key equality comparer, possibly null (in which case equals + * will be used), required to compare keys from the same input. * * <p>NOTE: The current API is experimental and subject to change without * notice. @@ -2232,14 +2236,15 @@ public abstract class EnumerableDefaults { final @Nullable Predicate2<TSource, TInner> extraPredicate, final Function2<TSource, @Nullable TInner, TResult> resultSelector, final JoinType joinType, - final @Nullable Comparator<TKey> comparator) { + final @Nullable Comparator<TKey> comparator, + final @Nullable EqualityComparer<TKey> equalityComparer) { if (!isMergeJoinSupported(joinType)) { throw new UnsupportedOperationException("MergeJoin unsupported for join type " + joinType); } return new AbstractEnumerable<TResult>() { @Override public Enumerator<TResult> enumerator() { return new MergeJoinEnumerator<>(outer, inner, outerKeySelector, innerKeySelector, - extraPredicate, resultSelector, joinType, comparator); + extraPredicate, resultSelector, joinType, comparator, equalityComparer); } }; } @@ -4144,7 +4149,7 @@ public abstract class EnumerableDefaults { } /** Enumerator that performs a merge join on its sorted inputs. - * Inputs must sorted in ascending order, nulls last. + * Inputs must be sorted in ascending order, nulls last. * * @param <TResult> result type * @param <TSource> left input record type @@ -4167,6 +4172,8 @@ public abstract class EnumerableDefaults { private final JoinType joinType; // key comparator, possibly null (Comparable#compareTo to be used in that case) private final @Nullable Comparator<TKey> comparator; + // key equality comparer, possibly null (equals to be used in that case) + private final @Nullable EqualityComparer<TKey> equalityComparer; private boolean done; private @Nullable Enumerator<TResult> results = null; // used for LEFT/ANTI join: if right input is over, all remaining elements from left are results @@ -4181,7 +4188,8 @@ public abstract class EnumerableDefaults { @Nullable Predicate2<TSource, TInner> extraPredicate, Function2<TSource, @Nullable TInner, TResult> resultSelector, JoinType joinType, - @Nullable Comparator<TKey> comparator) { + @Nullable Comparator<TKey> comparator, + @Nullable EqualityComparer<TKey> equalityComparer) { this.leftEnumerable = leftEnumerable; this.rightEnumerable = rightEnumerable; this.outerKeySelector = outerKeySelector; @@ -4190,6 +4198,7 @@ public abstract class EnumerableDefaults { this.resultSelector = resultSelector; this.joinType = joinType; this.comparator = comparator; + this.equalityComparer = equalityComparer; start(); } @@ -4258,15 +4267,18 @@ public abstract class EnumerableDefaults { results = Linq4j.emptyEnumerator(); } + /** Method to compare keys from left and right input (nulls must not be considered equal). */ private int compare(TKey key1, TKey key2) { - return comparator != null ? comparator.compare(key1, key2) : compareNullsLast(key1, key2); + return comparator != null + ? comparator.compare(key1, key2) + : compareNullsLastForMergeJoin(key1, key2); } - private int compareNullsLast(TKey v0, TKey v1) { - return v0 == v1 ? 0 - : v0 == null ? 1 - : v1 == null ? -1 - : v0.compareTo(v1); + /** Method to compare keys from the same input (nulls must be considered equal). */ + private boolean compareEquals(TKey key1, TKey key2) { + return equalityComparer != null + ? equalityComparer.equal(key1, key2) + : Objects.equals(key1, key2); } /** Moves to the next key that is present in both sides. Populates @@ -4291,7 +4303,14 @@ public abstract class EnumerableDefaults { done = true; return false; } - int c = compare(leftKey, rightKey); + int c; + try { + c = compare(leftKey, rightKey); + } catch (BothValuesAreNullException e) { + // consider the first (left) value as "bigger", to advance on the right value + // and continue with the algorithm + c = 1; + } if (c == 0) { break; } @@ -4390,13 +4409,7 @@ public abstract class EnumerableDefaults { // if we reach a null key, we are done (except LEFT join, that needs to process LHS fully) break; } - int c = compare(leftKey, leftKey2); - if (c != 0) { - if (c > 0) { - throw new IllegalStateException( - "mergeJoin assumes inputs sorted in ascending order, " + "however '" - + leftKey + "' is greater than '" + leftKey2 + "'"); - } + if (!compareEquals(leftKey, leftKey2)) { return true; } lefts.add(left); @@ -4423,13 +4436,7 @@ public abstract class EnumerableDefaults { // if we reach a null key, we are done break; } - int c = compare(rightKey, rightKey2); - if (c != 0) { - if (c > 0) { - throw new IllegalStateException( - "mergeJoin assumes input sorted in ascending order, " + "however '" - + rightKey + "' is greater than '" + rightKey2 + "'"); - } + if (!compareEquals(rightKey, rightKey2)) { return true; } rights.add(right); @@ -4495,6 +4502,36 @@ public abstract class EnumerableDefaults { } } + /** + * Exception used for control flow (it does not populate the stack trace to be more efficient) + * to signal that both values are <code>null</code>, so that the caller method (i.e. MergeJoin + * algorithm) can act accordingly. + */ + private static class BothValuesAreNullException extends RuntimeException { + @Override public synchronized Throwable fillInStackTrace() { + return this; + } + } + + public static int compareNullsLastForMergeJoin(@Nullable Comparable v0, @Nullable Comparable v1) { + return compareNullsLastForMergeJoin(v0, v1, null); + } + + public static int compareNullsLastForMergeJoin(@Nullable Comparable v0, @Nullable Comparable v1, + @Nullable Comparator comparator) { + // Special code for mergeJoin algorithm: in case of two null values, they must not be + // considered as equal (otherwise the join would return incorrect results) + if (v0 == null && v1 == null) { + throw new BothValuesAreNullException(); + } + + //noinspection unchecked + return v0 == v1 ? 0 + : v0 == null ? 1 + : v1 == null ? -1 + : comparator == null ? v0.compareTo(v1) : comparator.compare(v0, v1); + } + /** Enumerates the elements of a cartesian product of two inputs. * * @param <TResult> result type