This is an automated email from the ASF dual-hosted git repository. rubenql pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/master by this push: new bfde14b [CALCITE-3828] MergeJoin throws NPE in case of null keys bfde14b is described below commit bfde14be1284efc6d4560868fef3724238c35dc3 Author: rubenada <rube...@gmail.com> AuthorDate: Wed Feb 26 10:37:47 2020 +0100 [CALCITE-3828] MergeJoin throws NPE in case of null keys --- .../apache/calcite/runtime/EnumerablesTest.java | 100 ++++++++++++++++----- .../apache/calcite/linq4j/EnumerableDefaults.java | 20 ++++- 2 files changed, 95 insertions(+), 25 deletions(-) diff --git a/core/src/test/java/org/apache/calcite/runtime/EnumerablesTest.java b/core/src/test/java/org/apache/calcite/runtime/EnumerablesTest.java index dc1748c..1731653 100644 --- a/core/src/test/java/org/apache/calcite/runtime/EnumerablesTest.java +++ b/core/src/test/java/org/apache/calcite/runtime/EnumerablesTest.java @@ -122,52 +122,104 @@ public class EnumerablesTest { + " Emp(30, Greg), Dept(30, Development)]")); } + @Test public void testMergeJoinWithNullKeys() { + assertThat( + EnumerableDefaults.mergeJoin( + Linq4j.asEnumerable( + Arrays.asList( + new Emp(30, "Fred"), + new Emp(20, "Sebastian"), + new Emp(30, "Theodore"), + new Emp(20, "Theodore"), + new Emp(40, null), + new Emp(30, null))), + Linq4j.asEnumerable( + Arrays.asList( + new Dept(15, "Marketing"), + new Dept(20, "Sales"), + new Dept(30, "Theodore"), + new Dept(40, null))), + e -> e.name, + d -> d.name, + (v0, v1) -> v0 + ", " + v1, false, false).toList().toString(), + equalTo("[Emp(30, Theodore), Dept(30, Theodore)," + + " Emp(20, Theodore), Dept(30, Theodore)]")); + } + @Test public void testMergeJoin2() { // Matching keys at start - assertThat( - intersect(Lists.newArrayList(1, 3, 4), - Lists.newArrayList(1, 4)).toList().toString(), + testIntersect( + Lists.newArrayList(1, 3, 4), + Lists.newArrayList(1, 4), equalTo("[1, 4]")); // Matching key at start and end of right, not of left - assertThat( - intersect(Lists.newArrayList(0, 1, 3, 4, 5), - 
Lists.newArrayList(1, 4)).toList().toString(), + testIntersect( + Lists.newArrayList(0, 1, 3, 4, 5), + Lists.newArrayList(1, 4), equalTo("[1, 4]")); // Matching key at start and end of left, not right - assertThat( - intersect(Lists.newArrayList(1, 3, 4), - Lists.newArrayList(0, 1, 4, 5)).toList().toString(), + testIntersect( + Lists.newArrayList(1, 3, 4), + Lists.newArrayList(0, 1, 4, 5), equalTo("[1, 4]")); // Matching key not at start or end of left or right - assertThat( - intersect(Lists.newArrayList(0, 2, 3, 4, 5), - Lists.newArrayList(1, 3, 4, 6)).toList().toString(), + testIntersect( + Lists.newArrayList(0, 2, 3, 4, 5), + Lists.newArrayList(1, 3, 4, 6), equalTo("[3, 4]")); } @Test public void testMergeJoin3() { // No overlap - assertThat( - intersect(Lists.newArrayList(0, 2, 4), - Lists.newArrayList(1, 3, 5)).toList().toString(), + testIntersect( + Lists.newArrayList(0, 2, 4), + Lists.newArrayList(1, 3, 5), equalTo("[]")); // Left empty - assertThat( - intersect(new ArrayList<>(), - newArrayList(1, 3, 4, 6)).toList().toString(), + testIntersect( + new ArrayList<>(), + newArrayList(1, 3, 4, 6), equalTo("[]")); // Right empty - assertThat( - intersect(newArrayList(3, 7), - new ArrayList<>()).toList().toString(), + testIntersect( + newArrayList(3, 7), + new ArrayList<>(), equalTo("[]")); // Both empty - assertThat( - intersect(new ArrayList<Integer>(), - new ArrayList<>()).toList().toString(), + testIntersect( + new ArrayList<Integer>(), + new ArrayList<>(), equalTo("[]")); } + private static <T extends Comparable<T>> void testIntersect( + List<T> list0, List<T> list1, org.hamcrest.Matcher<String> matcher) { + assertThat( + intersect(list0, list1).toList().toString(), + matcher); + + // Repeat test with nulls at the end of left / right: result should not be impacted + + // Null at the end of left + list0.add(null); + assertThat( + intersect(list0, list1).toList().toString(), + matcher); + + // Null at the end of right + list0.remove(list0.size() - 1); + 
list1.add(null); + assertThat( + intersect(list0, list1).toList().toString(), + matcher); + + // Null at the end of left and right + list0.add(null); + assertThat( + intersect(list0, list1).toList().toString(), + matcher); + } + private static <T extends Comparable<T>> Enumerable<T> intersect( List<T> list0, List<T> list1) { return EnumerableDefaults.mergeJoin( diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java index a20f5fa..6e35699 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java @@ -1959,7 +1959,10 @@ public abstract class EnumerableDefaults { }; } - /** Joins two inputs that are sorted on the key. */ + /** + * Joins two inputs that are sorted on the key. + * Inputs must be sorted in ascending order, nulls last. + */ public static <TSource, TInner, TKey extends Comparable<TKey>, TResult> Enumerable<TResult> mergeJoin(final Enumerable<TSource> outer, final Enumerable<TInner> inner, @@ -3785,6 +3788,7 @@ public abstract class EnumerableDefaults { } /** Enumerator that performs a merge join on its sorted inputs. + * Inputs must be sorted in ascending order, nulls last. * * @param <TResult> result type * @param <TSource> left input record type @@ -3833,6 +3837,12 @@ public abstract class EnumerableDefaults { TInner right = rightEnumerator.current(); TKey rightKey = innerKeySelector.apply(right); for (;;) { + // mergeJoin assumes inputs sorted in ascending order with nulls last; + // if we reach a null key, we are done.
+ if (leftKey == null || rightKey == null) { + done = true; + return false; + } int c = leftKey.compareTo(rightKey); if (c == 0) { break; @@ -3862,6 +3872,10 @@ public abstract class EnumerableDefaults { } left = leftEnumerator.current(); TKey leftKey2 = outerKeySelector.apply(left); + if (leftKey2 == null) { + done = true; + break; + } int c = leftKey.compareTo(leftKey2); if (c != 0) { if (c > 0) { @@ -3882,6 +3896,10 @@ public abstract class EnumerableDefaults { } right = rightEnumerator.current(); TKey rightKey2 = innerKeySelector.apply(right); + if (rightKey2 == null) { + done = true; + break; + } int c = rightKey.compareTo(rightKey2); if (c != 0) { if (c > 0) {