This is an automated email from the ASF dual-hosted git repository. rubenql pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/master by this push: new 4cd90f3 [CALCITE-3221] Add MergeUnion operator in Enumerable convention 4cd90f3 is described below commit 4cd90f36c3cf9a012e34f14129a907a4ce99c6f5 Author: rubenada <rube...@gmail.com> AuthorDate: Mon Jan 18 16:30:07 2021 +0000 [CALCITE-3221] Add MergeUnion operator in Enumerable convention --- .../adapter/enumerable/EnumerableMergeUnion.java | 118 +++++ .../enumerable/EnumerableMergeUnionRule.java | 105 +++++ .../adapter/enumerable/EnumerableRules.java | 6 + .../calcite/rel/metadata/RelMdCollation.java | 12 + .../java/org/apache/calcite/tools/Programs.java | 1 + .../org/apache/calcite/util/BuiltInMethod.java | 2 + .../apache/calcite/runtime/EnumerablesTest.java | 515 +++++++++++++++++++++ .../test/enumerable/EnumerableMergeUnionTest.java | 309 +++++++++++++ .../apache/calcite/linq4j/EnumerableDefaults.java | 37 +- .../calcite/linq4j/MergeUnionEnumerator.java | 208 +++++++++ 10 files changed, 1311 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeUnion.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeUnion.java new file mode 100644 index 0000000..182e0a5 --- /dev/null +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeUnion.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.adapter.enumerable; + +import org.apache.calcite.linq4j.Ord; +import org.apache.calcite.linq4j.tree.BlockBuilder; +import org.apache.calcite.linq4j.tree.Expression; +import org.apache.calcite.linq4j.tree.Expressions; +import org.apache.calcite.linq4j.tree.ParameterExpression; +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.util.BuiltInMethod; +import org.apache.calcite.util.Pair; +import org.apache.calcite.util.Util; + +import java.util.ArrayList; +import java.util.List; + +/** Implementation of {@link org.apache.calcite.rel.core.Union} in + * {@link org.apache.calcite.adapter.enumerable.EnumerableConvention enumerable calling convention}. + * Performs a union (or union all) of all its inputs (which must be already sorted), + * respecting the order. */ +public class EnumerableMergeUnion extends EnumerableUnion { + + protected EnumerableMergeUnion(RelOptCluster cluster, RelTraitSet traitSet, List<RelNode> inputs, + boolean all) { + super(cluster, traitSet, inputs, all); + final RelCollation collation = traitSet.getCollation(); + if (collation == null || collation.getFieldCollations().isEmpty()) { + throw new IllegalArgumentException("EnumerableMergeUnion with no collation"); + } + for (RelNode input : inputs) { + final RelCollation inputCollation = input.getTraitSet().getCollation(); + if (inputCollation == null || !inputCollation.satisfies(collation)) { + throw new IllegalArgumentException("EnumerableMergeUnion input does not satisfy collation. " + + "EnumerableMergeUnion collation: " + collation + ". Input collation: " + + inputCollation + ". Input: " + input); + } + } + } + + public static EnumerableMergeUnion create(RelCollation collation, List<RelNode> inputs, + boolean all) { + final RelOptCluster cluster = inputs.get(0).getCluster(); + final RelTraitSet traitSet = cluster.traitSetOf(EnumerableConvention.INSTANCE).replace( + collation); + return new EnumerableMergeUnion(cluster, traitSet, inputs, all); + } + + @Override public EnumerableMergeUnion copy(RelTraitSet traitSet, List<RelNode> inputs, + boolean all) { + return new EnumerableMergeUnion(getCluster(), traitSet, inputs, all); + } + + @Override public Result implement(EnumerableRelImplementor implementor, Prefer pref) { + final BlockBuilder builder = new BlockBuilder(); + + final ParameterExpression inputListExp = Expressions.parameter( + List.class, + builder.newName("mergeUnionInputs" + getId())); + builder.add(Expressions.declare(0, inputListExp, Expressions.new_(ArrayList.class))); + + for (Ord<RelNode> ord : Ord.zip(inputs)) { + final EnumerableRel input = (EnumerableRel) ord.e; + final Result result = implementor.visitChild(this, ord.i, input, pref); + final Expression childExp = builder.append("child" + ord.i, result.block); + builder.add( + Expressions.statement( + Expressions.call(inputListExp, BuiltInMethod.COLLECTION_ADD.method, childExp))); + } + + final PhysType physType = PhysTypeImpl.of( + implementor.getTypeFactory(), + getRowType(), + pref.prefer(JavaRowFormat.CUSTOM)); + + final RelCollation collation = getTraitSet().getCollation(); + if (collation == null || collation.getFieldCollations().isEmpty()) { + // should not happen + throw new IllegalStateException("EnumerableMergeUnion with no collation"); + } + final Pair<Expression, Expression> pair = + physType.generateCollationKey(collation.getFieldCollations()); + final Expression sortKeySelector = pair.left; + final Expression sortComparator = pair.right; + + final Expression equalityComparator = Util.first( + physType.comparer(), + Expressions.call(BuiltInMethod.IDENTITY_COMPARER.method)); + + final Expression unionExp = Expressions.call( + BuiltInMethod.MERGE_UNION.method, + inputListExp, + sortKeySelector, + sortComparator, + Expressions.constant(all, boolean.class), + equalityComparator); + builder.add(unionExp); + + return implementor.result(physType, builder.toBlock()); + } +} diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeUnionRule.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeUnionRule.java new file mode 100644 index 0000000..be19d1f --- /dev/null +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeUnionRule.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.adapter.enumerable; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Sort; +import org.apache.calcite.rel.core.Union; +import org.apache.calcite.rel.logical.LogicalSort; +import org.apache.calcite.rel.logical.LogicalUnion; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; + +import java.util.ArrayList; +import java.util.List; + +/** + * Rule to convert a {@link org.apache.calcite.rel.logical.LogicalSort} on top of a + * {@link org.apache.calcite.rel.logical.LogicalUnion} into a {@link EnumerableMergeUnion}. + * + * @see EnumerableRules#ENUMERABLE_MERGE_UNION_RULE + */ +public class EnumerableMergeUnionRule extends RelRule<EnumerableMergeUnionRule.Config> { + + /** Rule configuration. */ + public interface Config extends RelRule.Config { + Config DEFAULT_CONFIG = EMPTY.withDescription("EnumerableMergeUnionRule").withOperandSupplier( + b0 -> b0.operand(LogicalSort.class).oneInput( + b1 -> b1.operand(LogicalUnion.class).anyInputs())).as(Config.class); + + @Override default EnumerableMergeUnionRule toRule() { + return new EnumerableMergeUnionRule(this); + } + } + + public EnumerableMergeUnionRule(Config config) { + super(config); + } + + @Override public boolean matches(RelOptRuleCall call) { + final Sort sort = call.rel(0); + final RelCollation collation = sort.getCollation(); + if (collation == null || collation.getFieldCollations().isEmpty()) { + return false; + } + + final Union union = call.rel(1); + if (union.getInputs().size() < 2) { + return false; + } + + return true; + } + + @Override public void onMatch(RelOptRuleCall call) { + final Sort sort = call.rel(0); + final RelCollation collation = sort.getCollation(); + final Union union = call.rel(1); + final int unionInputsSize = union.getInputs().size(); + + // Push down sort limit, if possible. + RexNode inputFetch = null; + if (sort.fetch != null) { + if (sort.offset == null) { + inputFetch = sort.fetch; + } else if (sort.fetch instanceof RexLiteral && sort.offset instanceof RexLiteral) { + inputFetch = call.builder().literal( + RexLiteral.intValue(sort.fetch) + RexLiteral.intValue(sort.offset)); + } + } + + final List<RelNode> inputs = new ArrayList<>(unionInputsSize); + for (RelNode input : union.getInputs()) { + final RelNode newInput = sort.copy(sort.getTraitSet(), input, collation, null, inputFetch); + inputs.add( + convert(newInput, newInput.getTraitSet().replace(EnumerableConvention.INSTANCE))); + } + + RelNode result = EnumerableMergeUnion.create(sort.getCollation(), inputs, union.all); + + // If Sort contained a LIMIT / OFFSET, then put it back as an EnumerableLimit. + // The output of the MergeUnion is already sorted, so we do not need a sort anymore. + if (sort.offset != null || sort.fetch != null) { + result = EnumerableLimit.create(result, sort.offset, sort.fetch); + } + + call.transformTo(result); + } +} diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableRules.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableRules.java index 8816923..17a7465 100644 --- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableRules.java +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableRules.java @@ -103,6 +103,11 @@ public class EnumerableRules { EnumerableRepeatUnionRule.DEFAULT_CONFIG .toRule(EnumerableRepeatUnionRule.class); + /** Rule to convert a {@link org.apache.calcite.rel.logical.LogicalSort} on top of a + * {@link org.apache.calcite.rel.logical.LogicalUnion} into a {@link EnumerableMergeUnion}. */ + public static final EnumerableMergeUnionRule ENUMERABLE_MERGE_UNION_RULE = + EnumerableMergeUnionRule.Config.DEFAULT_CONFIG.toRule(); + /** Rule that converts a {@link LogicalTableSpool} into an * {@link EnumerableTableSpool}. */ @Experimental @@ -210,6 +215,7 @@ public class EnumerableRules { EnumerableRules.ENUMERABLE_LIMIT_RULE, EnumerableRules.ENUMERABLE_COLLECT_RULE, EnumerableRules.ENUMERABLE_UNCOLLECT_RULE, + EnumerableRules.ENUMERABLE_MERGE_UNION_RULE, EnumerableRules.ENUMERABLE_UNION_RULE, EnumerableRules.ENUMERABLE_REPEAT_UNION_RULE, EnumerableRules.ENUMERABLE_TABLE_SPOOL_RULE, diff --git a/core/src/main/java/org/apache/calcite/rel/metadata/RelMdCollation.java b/core/src/main/java/org/apache/calcite/rel/metadata/RelMdCollation.java index 1fc2336..850f99a 100644 --- a/core/src/main/java/org/apache/calcite/rel/metadata/RelMdCollation.java +++ b/core/src/main/java/org/apache/calcite/rel/metadata/RelMdCollation.java @@ -19,6 +19,7 @@ package org.apache.calcite.rel.metadata; import org.apache.calcite.adapter.enumerable.EnumerableCorrelate; import org.apache.calcite.adapter.enumerable.EnumerableHashJoin; import org.apache.calcite.adapter.enumerable.EnumerableMergeJoin; +import org.apache.calcite.adapter.enumerable.EnumerableMergeUnion; import org.apache.calcite.adapter.enumerable.EnumerableNestedLoopJoin; import org.apache.calcite.adapter.jdbc.JdbcToEnumerableConverter; import org.apache.calcite.linq4j.Ord; @@ -175,6 +176,17 @@ public class RelMdCollation join.getJoinType())); } + public @Nullable ImmutableList<RelCollation> collations(EnumerableMergeUnion mergeUnion, + RelMetadataQuery mq) { + final RelCollation collation = mergeUnion.getTraitSet().getCollation(); + if (collation == null) { + // should not happen + return null; + } + // MergeUnion guarantees order, like a sort + return copyOf(RelMdCollation.sort(collation)); + } + public @Nullable ImmutableList<RelCollation> collations(EnumerableCorrelate join, RelMetadataQuery mq) { return copyOf( diff --git a/core/src/main/java/org/apache/calcite/tools/Programs.java b/core/src/main/java/org/apache/calcite/tools/Programs.java index c4f1429..b973c5c 100644 --- a/core/src/main/java/org/apache/calcite/tools/Programs.java +++ b/core/src/main/java/org/apache/calcite/tools/Programs.java @@ -81,6 +81,7 @@ public class Programs { EnumerableRules.ENUMERABLE_SORT_RULE, EnumerableRules.ENUMERABLE_LIMIT_RULE, EnumerableRules.ENUMERABLE_UNION_RULE, + EnumerableRules.ENUMERABLE_MERGE_UNION_RULE, EnumerableRules.ENUMERABLE_INTERSECT_RULE, EnumerableRules.ENUMERABLE_MINUS_RULE, EnumerableRules.ENUMERABLE_TABLE_MODIFICATION_RULE, diff --git a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java index 84008a6..69a632d 100644 --- a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java +++ b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java @@ -239,6 +239,8 @@ public enum BuiltInMethod { CONCAT(ExtendedEnumerable.class, "concat", Enumerable.class), REPEAT_UNION(EnumerableDefaults.class, "repeatUnion", Enumerable.class, Enumerable.class, int.class, boolean.class, EqualityComparer.class), + MERGE_UNION(EnumerableDefaults.class, "mergeUnion", List.class, Function1.class, + Comparator.class, boolean.class, EqualityComparer.class), LAZY_COLLECTION_SPOOL(EnumerableDefaults.class, "lazyCollectionSpool", Collection.class, Enumerable.class), INTERSECT(ExtendedEnumerable.class, "intersect", Enumerable.class, boolean.class), diff --git a/core/src/test/java/org/apache/calcite/runtime/EnumerablesTest.java b/core/src/test/java/org/apache/calcite/runtime/EnumerablesTest.java index f615d92..80d1a51 100644 --- a/core/src/test/java/org/apache/calcite/runtime/EnumerablesTest.java +++ b/core/src/test/java/org/apache/calcite/runtime/EnumerablesTest.java @@ -20,6 +20,7 @@ import org.apache.calcite.linq4j.Enumerable; import org.apache.calcite.linq4j.EnumerableDefaults; import org.apache.calcite.linq4j.JoinType; import org.apache.calcite.linq4j.Linq4j; +import org.apache.calcite.linq4j.function.EqualityComparer; import org.apache.calcite.linq4j.function.Function2; import org.apache.calcite.linq4j.function.Functions; import org.apache.calcite.linq4j.function.Predicate2; @@ -31,8 +32,10 @@ import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.Arrays; +import java.util.Comparator; import java.util.List; import java.util.Locale; +import java.util.Objects; import static com.google.common.collect.Lists.newArrayList; @@ -966,6 +969,503 @@ class EnumerablesTest { + " null, Dept(30, Development)]")); } + @Test void testMergeUnionAllEmptyOnRight() { + assertThat( + EnumerableDefaults.mergeUnion( + Arrays.asList( + Linq4j.asEnumerable( + Arrays.asList( + new Emp(20, "Lilly"), + new Emp(30, "Joe"), + new Emp(30, "Greg"))), + Linq4j.emptyEnumerable()), + e -> e.deptno, + INTEGER_ASC, + true, + EMP_EQUALITY_COMPARER).toList().toString(), + equalTo("[Emp(20, Lilly), Emp(30, Joe), Emp(30, Greg)]")); + } + + @Test void testMergeUnionAllEmptyOnLeft() { + assertThat( + EnumerableDefaults.mergeUnion( + Arrays.asList( + Linq4j.emptyEnumerable(), + Linq4j.asEnumerable( + Arrays.asList( + new Emp(20, "Lilly"), + new Emp(30, "Joe"), + new Emp(30, "Greg")))), + e -> e.deptno, + INTEGER_ASC, + true, + EMP_EQUALITY_COMPARER).toList().toString(), + equalTo("[Emp(20, Lilly), Emp(30, Joe), Emp(30, Greg)]")); + } + + @Test void testMergeUnionAllEmptyOnBoth() { + assertThat( + EnumerableDefaults.mergeUnion( + Arrays.asList( + Linq4j.emptyEnumerable(), + Linq4j.emptyEnumerable()), + e -> e.deptno, + INTEGER_ASC, + true, + EMP_EQUALITY_COMPARER).toList().toString(), + equalTo("[]")); + } + + @Test void testMergeUnionAllOrderByDeptAsc2inputs() { + assertThat( + EnumerableDefaults.mergeUnion( + Arrays.asList( + Linq4j.asEnumerable( + Arrays.asList( + new Emp(20, "Lilly"), + new Emp(30, "Joe"), + new Emp(30, "Greg"))), + Linq4j.asEnumerable( + Arrays.asList( + new Emp(10, "Fred"), + new Emp(30, "Theodore"), + new Emp(40, "Sebastian")))), + e -> e.deptno, + INTEGER_ASC, + true, + EMP_EQUALITY_COMPARER).toList().toString(), + equalTo( + "[Emp(10, Fred), Emp(20, Lilly), Emp(30, Joe), Emp(30, Greg), Emp(30, Theodore), Emp(40, Sebastian)]")); + } + + @Test void testMergeUnionAllOrderByDeptAsc3inputs() { + assertThat( + EnumerableDefaults.mergeUnion( + Arrays.asList( + Linq4j.asEnumerable( + Arrays.asList( + new Emp(20, "Lilly"), + new Emp(30, "Joe"), + new Emp(30, "Greg"))), + Linq4j.asEnumerable( + Arrays.asList( + new Emp(15, "Phyllis"), + new Emp(18, "Maddie"), + new Emp(22, "Jenny"), + new Emp(42, "Susan"))), + Linq4j.asEnumerable( + Arrays.asList( + new Emp(10, "Fred"), + new Emp(30, "Joe"), + new Emp(40, "Sebastian")))), + e -> e.deptno, + INTEGER_ASC, + true, + EMP_EQUALITY_COMPARER).toList().toString(), + equalTo( + "[Emp(10, Fred), Emp(15, Phyllis), Emp(18, Maddie), Emp(20, Lilly), Emp(22, Jenny)," + + " Emp(30, Joe), Emp(30, Greg), Emp(30, Joe), Emp(40, Sebastian), Emp(42, Susan)]")); + } + + @Test void testMergeUnionOrderByDeptAsc3inputs() { + assertThat( + EnumerableDefaults.mergeUnion( + Arrays.asList( + Linq4j.asEnumerable( + Arrays.asList( + new Emp(15, "Phyllis"), + new Emp(15, "Phyllis"), + new Emp(20, "Lilly"), + new Emp(30, "Joe"), + new Emp(30, "Greg"))), + Linq4j.asEnumerable( + Arrays.asList( + new Emp(15, "Phyllis"), + new Emp(18, "Maddie"), + new Emp(22, "Jenny"), + new Emp(30, "Joe"), + new Emp(42, "Susan"))), + Linq4j.asEnumerable( + Arrays.asList( + new Emp(10, "Fred"), + new Emp(15, "Phyllis"), + new Emp(30, "Joe"), + new Emp(30, "Joe"), + new Emp(40, "Sebastian")))), + e -> e.deptno, + INTEGER_ASC, + false, + EMP_EQUALITY_COMPARER).toList().toString(), + equalTo( + "[Emp(10, Fred), Emp(15, Phyllis), Emp(18, Maddie), Emp(20, Lilly), Emp(22, Jenny)," + + " Emp(30, Joe), Emp(30, Greg), Emp(40, Sebastian), Emp(42, Susan)]")); + } + + @Test void testMergeUnionAllOrderByDeptDesc2inputs() { + assertThat( + EnumerableDefaults.mergeUnion( + Arrays.asList( + Linq4j.asEnumerable( + Arrays.asList( + new Emp(42, "Lilly"), + new Emp(30, "Joe"), + new Emp(30, "Greg"))), + Linq4j.asEnumerable( + Arrays.asList( + new Emp(50, "Fred"), + new Emp(30, "Theodore"), + new Emp(10, "Sebastian")))), + e -> e.deptno, + INTEGER_DESC, + true, + EMP_EQUALITY_COMPARER).toList().toString(), + equalTo( + "[Emp(50, Fred), Emp(42, Lilly), Emp(30, Joe), Emp(30, Greg), Emp(30, Theodore), Emp(10, Sebastian)]")); + } + + @Test void testMergeUnionAllOrderByDeptDesc3inputs() { + assertThat( + EnumerableDefaults.mergeUnion( + Arrays.asList( + Linq4j.asEnumerable( + Arrays.asList( + new Emp(35, "Lilly"), + new Emp(22, "Jenny"), + new Emp(20, "Joe"), + new Emp(20, "Greg"))), + Linq4j.asEnumerable( + Arrays.asList( + new Emp(45, "Phyllis"), + new Emp(42, "Maddie"), + new Emp(22, "Jenny"), + new Emp(22, "Jenny"), + new Emp(12, "Susan"))), + Linq4j.asEnumerable( + Arrays.asList( + new Emp(50, "Fred"), + new Emp(20, "Theodore"), + new Emp(15, "Sebastian")))), + e -> e.deptno, + INTEGER_DESC, + true, + EMP_EQUALITY_COMPARER).toList().toString(), + equalTo( + "[Emp(50, Fred), Emp(45, Phyllis), Emp(42, Maddie), Emp(35, Lilly), Emp(22, Jenny)," + + " Emp(22, Jenny), Emp(22, Jenny), Emp(20, Joe), Emp(20, Greg), Emp(20, Theodore), Emp(15, Sebastian), Emp(12, Susan)]")); + } + + @Test void testMergeUnionOrderByDeptDesc3inputs() { + assertThat( + EnumerableDefaults.mergeUnion( + Arrays.asList( + Linq4j.asEnumerable( + Arrays.asList( + new Emp(35, "Lilly"), + new Emp(22, "Jenny"), + new Emp(22, "Jenny"), + new Emp(20, "Joe"), + new Emp(20, "Greg"))), + Linq4j.asEnumerable( + Arrays.asList( + new Emp(45, "Phyllis"), + new Emp(42, "Maddie"), + new Emp(22, "Jenny"), + new Emp(12, "Susan"))), + Linq4j.asEnumerable( + Arrays.asList( + new Emp(50, "Fred"), + new Emp(22, "Jenny"), + new Emp(20, "Theodore"), + new Emp(20, "Joe"), + new Emp(15, "Sebastian")))), + e -> e.deptno, + INTEGER_DESC, + false, + EMP_EQUALITY_COMPARER).toList().toString(), + equalTo( + "[Emp(50, Fred), Emp(45, Phyllis), Emp(42, Maddie), Emp(35, Lilly), Emp(22, Jenny)," + + " Emp(20, Joe), Emp(20, Greg), Emp(20, Theodore), Emp(15, Sebastian), Emp(12, Susan)]")); + } + + @Test void testMergeUnionAllOrderByNameAscNullsFirst() { + assertThat( + EnumerableDefaults.mergeUnion( + Arrays.asList( + Linq4j.asEnumerable( + Arrays.asList( + new Emp(20, null), + new Emp(10, null), + new Emp(30, "Greg"))), + Linq4j.asEnumerable( + Arrays.asList( + new Emp(20, null), + new Emp(30, "Sebastian"), + new Emp(10, "Theodore")))), + e -> e.name, + STRING_ASC_NULLS_FIRST, + true, + EMP_EQUALITY_COMPARER).toList().toString(), + equalTo( + "[Emp(20, null), Emp(10, null), Emp(20, null), Emp(30, Greg), Emp(30, Sebastian), Emp(10, Theodore)]")); + } + + @Test void testMergeUnionOrderByNameAscNullsFirst() { + assertThat( + EnumerableDefaults.mergeUnion( + Arrays.asList( + Linq4j.asEnumerable( + Arrays.asList( + new Emp(20, null), + new Emp(10, null), + new Emp(30, "Greg"))), + Linq4j.asEnumerable( + Arrays.asList( + new Emp(20, null), + new Emp(30, "Sebastian"), + new Emp(10, "Theodore")))), + e -> e.name, + STRING_ASC_NULLS_FIRST, + false, + EMP_EQUALITY_COMPARER).toList().toString(), + equalTo( + "[Emp(20, null), Emp(10, null), Emp(30, Greg), Emp(30, Sebastian), Emp(10, Theodore)]")); + } + + @Test void testMergeUnionAllOrderByNameDescNullsFirst() { + assertThat( + EnumerableDefaults.mergeUnion( + Arrays.asList( + Linq4j.asEnumerable( + Arrays.asList( + new Emp(20, null), + new Emp(10, null), + new Emp(30, "Greg"))), + Linq4j.asEnumerable( + Arrays.asList( + new Emp(20, null), + new Emp(30, "Theodore"), + new Emp(10, "Sebastian")))), + e -> e.name, + STRING_DESC_NULLS_FIRST, + true, + EMP_EQUALITY_COMPARER).toList().toString(), + equalTo( + "[Emp(20, null), Emp(10, null), Emp(20, null), Emp(30, Theodore), Emp(10, Sebastian), Emp(30, Greg)]")); + } + + @Test void testMergeUnionOrderByNameDescNullsFirst() { + assertThat( + EnumerableDefaults.mergeUnion( + Arrays.asList( + Linq4j.asEnumerable( + Arrays.asList( + new Emp(20, null), + new Emp(10, null), + new Emp(30, "Greg"))), + Linq4j.asEnumerable( + Arrays.asList( + new Emp(20, null), + new Emp(30, "Theodore"), + new Emp(10, "Sebastian")))), + e -> e.name, + STRING_DESC_NULLS_FIRST, + false, + EMP_EQUALITY_COMPARER).toList().toString(), + equalTo( + "[Emp(20, null), Emp(10, null), Emp(30, Theodore), Emp(10, Sebastian), Emp(30, Greg)]")); + } + + @Test void testMergeUnionAllOrderByNameAscNullsLast() { + assertThat( + EnumerableDefaults.mergeUnion( + Arrays.asList( + Linq4j.asEnumerable( + Arrays.asList( + new Emp(20, "Greg"), + new Emp(10, null), + new Emp(30, null))), + Linq4j.asEnumerable( + Arrays.asList( + new Emp(20, "Greg"), + new Emp(30, "Sebastian"), + new Emp(30, "Theodore"), + new Emp(10, null)))), + e -> e.name, + STRING_ASC_NULLS_LAST, + true, + EMP_EQUALITY_COMPARER).toList().toString(), + equalTo( + "[Emp(20, Greg), Emp(20, Greg), Emp(30, Sebastian), Emp(30, Theodore), Emp(10, null), Emp(30, null), Emp(10, null)]")); + } + + @Test void testMergeUnionOrderByNameAscNullsLast() { + assertThat( + EnumerableDefaults.mergeUnion( + Arrays.asList( + Linq4j.asEnumerable( + Arrays.asList( + new Emp(20, "Greg"), + new Emp(10, null), + new Emp(30, null))), + Linq4j.asEnumerable( + Arrays.asList( + new Emp(20, "Greg"), + new Emp(30, "Sebastian"), + new Emp(30, "Theodore"), + new Emp(10, null)))), + e -> e.name, + STRING_ASC_NULLS_LAST, + false, + EMP_EQUALITY_COMPARER).toList().toString(), + equalTo( + "[Emp(20, Greg), Emp(30, Sebastian), Emp(30, Theodore), Emp(10, null), Emp(30, null)]")); + } + + @Test void testMergeUnionAllOrderByNameDescNullsLast() { + assertThat( + EnumerableDefaults.mergeUnion( + Arrays.asList( + Linq4j.asEnumerable( + Arrays.asList( + new Emp(20, "Greg"), + new Emp(10, null), + new Emp(30, null))), + Linq4j.asEnumerable( + Arrays.asList( + new Emp(30, "Theodore"), + new Emp(30, "Sebastian"), + new Emp(20, "Greg"), + new Emp(10, null)))), + e -> e.name, + STRING_DESC_NULLS_LAST, + true, + EMP_EQUALITY_COMPARER).toList().toString(), + equalTo( + "[Emp(30, Theodore), Emp(30, Sebastian), Emp(20, Greg), Emp(20, Greg), Emp(10, null), Emp(30, null), Emp(10, null)]")); + } + + @Test void testMergeUnionOrderByNameDescNullsLast() { + assertThat( + EnumerableDefaults.mergeUnion( + Arrays.asList( + Linq4j.asEnumerable( + Arrays.asList( + new Emp(20, "Greg"), + new Emp(10, null), + new Emp(30, null))), + Linq4j.asEnumerable( + Arrays.asList( + new Emp(30, "Theodore"), + new Emp(30, "Sebastian"), + new Emp(20, "Greg"), + new Emp(10, null)))), + e -> e.name, + STRING_DESC_NULLS_LAST, + false, + EMP_EQUALITY_COMPARER).toList().toString(), + equalTo( + "[Emp(30, Theodore), Emp(30, Sebastian), Emp(20, Greg), Emp(10, null), Emp(30, null)]")); + } + + @Test void testMergeUnionAllOrderByDeptAscNameDescNullsFirst() { + assertThat( + EnumerableDefaults.mergeUnion( + Arrays.asList( + Linq4j.asEnumerable( + Arrays.asList(new Emp(10, null))), + Linq4j.asEnumerable( + Arrays.asList( + new Emp(20, "Lilly"), + new Emp(20, "Lilly"), + new Emp(20, "Antoine"), + new Emp(22, null), + new Emp(30, "Joe"), + new Emp(30, "Greg"))), + Linq4j.asEnumerable( + Arrays.asList( + new Emp(20, null), + new Emp(20, "Annie"), + new Emp(22, "Jenny"), + new Emp(42, "Susan"))), + Linq4j.asEnumerable( + Arrays.asList(new Emp(50, "Lolly"))), + Linq4j.asEnumerable( + Arrays.asList( + new Emp(10, "Fred"), + new Emp(20, "Lilly"), + new Emp(22, null), + new Emp(30, "Joe"), + new Emp(40, "Sebastian")))), + e -> e, + DEPT_ASC_AND_NAME_DESC_NULLS_FIRST, + true, + EMP_EQUALITY_COMPARER).toList().toString(), + equalTo( + "[Emp(10, null), Emp(10, Fred), Emp(20, null), Emp(20, Lilly), Emp(20, Lilly), Emp(20, Lilly)," + + " Emp(20, Antoine), Emp(20, Annie), Emp(22, null), Emp(22, null), Emp(22, Jenny)," + + " Emp(30, Joe), Emp(30, Joe), Emp(30, Greg), Emp(40, Sebastian), Emp(42, Susan), Emp(50, Lolly)]")); + } + + @Test void testMergeUnionOrderByDeptAscNameDescNullsFirst() { + assertThat( + EnumerableDefaults.mergeUnion( + Arrays.asList( + Linq4j.asEnumerable( + Arrays.asList(new Emp(10, null))), + Linq4j.asEnumerable( + Arrays.asList( + new Emp(20, "Lilly"), + new Emp(20, "Lilly"), + new Emp(20, "Antoine"), + new Emp(22, null), + new Emp(30, "Joe"), + new Emp(30, "Greg"))), + Linq4j.asEnumerable( + Arrays.asList( + new Emp(20, null), + new Emp(20, "Annie"), + new Emp(22, "Jenny"), + new Emp(42, "Susan"))), + Linq4j.asEnumerable( + Arrays.asList(new Emp(50, "Lolly"))), + Linq4j.asEnumerable( + Arrays.asList( + new Emp(10, "Fred"), + new Emp(20, "Lilly"), + new Emp(22, null), + new Emp(30, "Joe"), + new Emp(40, "Sebastian")))), + e -> e, + DEPT_ASC_AND_NAME_DESC_NULLS_FIRST, + false, + EMP_EQUALITY_COMPARER).toList().toString(), + equalTo( + "[Emp(10, null), Emp(10, Fred), Emp(20, null), Emp(20, Lilly)," + + " Emp(20, Antoine), Emp(20, Annie), Emp(22, null), Emp(22, Jenny)," + + " Emp(30, Joe), Emp(30, Greg), Emp(40, Sebastian), Emp(42, Susan), Emp(50, Lolly)]")); + } + + private static final Comparator<Integer> INTEGER_ASC = Integer::compare; + private static final Comparator<Integer> INTEGER_DESC = INTEGER_ASC.reversed(); + + private static final Comparator<String> STRING_ASC = Comparator.naturalOrder(); + private static final Comparator<String> STRING_DESC = STRING_ASC.reversed(); + + private static final Comparator<String> STRING_ASC_NULLS_FIRST = + Comparator.nullsFirst(STRING_ASC); + private static final Comparator<String> STRING_ASC_NULLS_LAST = + Comparator.nullsLast(STRING_ASC); + private static final Comparator<String> STRING_DESC_NULLS_FIRST = + Comparator.nullsFirst(STRING_DESC); + private static final Comparator<String> STRING_DESC_NULLS_LAST = + Comparator.nullsLast(STRING_DESC); + + private static final Comparator<Emp> DEPT_ASC_AND_NAME_DESC_NULLS_FIRST = + Comparator.<Emp>comparingInt(emp -> emp.deptno) + .thenComparing(emp -> emp.name, STRING_DESC_NULLS_FIRST); + + private static final EqualityComparer<Emp> EMP_EQUALITY_COMPARER = Functions.identityComparer(); + /** Employee record. */ private static class Emp { final int deptno; @@ -976,6 +1476,21 @@ class EnumerablesTest { this.name = name; } + @Override public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || this.getClass() != o.getClass()) { + return false; + } + final Emp emp = (Emp) o; + return this.deptno == emp.deptno && Objects.equals(this.name, emp.name); + } + + @Override public int hashCode() { + return Objects.hash(this.deptno, this.name); + } + @Override public String toString() { return "Emp(" + deptno + ", " + name + ")"; } diff --git a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableMergeUnionTest.java b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableMergeUnionTest.java new file mode 100644 index 0000000..483c0fc --- /dev/null +++ b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableMergeUnionTest.java @@ -0,0 +1,309 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.test.enumerable; + +import org.apache.calcite.adapter.enumerable.EnumerableRules; +import org.apache.calcite.adapter.java.ReflectiveSchema; +import org.apache.calcite.config.CalciteConnectionProperty; +import org.apache.calcite.config.Lex; +import org.apache.calcite.plan.RelOptPlanner; +import org.apache.calcite.runtime.Hook; +import org.apache.calcite.test.CalciteAssert; +import org.apache.calcite.test.JdbcTest; + +import org.junit.jupiter.api.Test; + +import java.util.function.Consumer; + +/** + * Unit test for + * {@link org.apache.calcite.adapter.enumerable.EnumerableMergeUnion}. + */ +class EnumerableMergeUnionTest { + + @Test void mergeUnionAllOrderByEmpid() { + tester(false, + new JdbcTest.HrSchemaBig(), + "select * from (select empid, name from emps where name like 'G%' union all select empid, name from emps where name like '%l') order by empid") + .explainContains("EnumerableMergeUnion(all=[true])\n" + + " EnumerableSort(sort0=[$0], dir0=[ASC])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=['G%'], expr#6=[LIKE($t2, $t5)], empid=[$t0], name=[$t2], $condition=[$t6])\n" + + " EnumerableTableScan(table=[[s, emps]])\n" + + " EnumerableSort(sort0=[$0], dir0=[ASC])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=['%l'], expr#6=[LIKE($t2, $t5)], empid=[$t0], name=[$t2], $condition=[$t6])\n" + + " EnumerableTableScan(table=[[s, emps]])\n") + .returnsOrdered( + "empid=1; name=Bill", + "empid=6; name=Guy", + "empid=10; name=Gabriel", + "empid=10; name=Gabriel", + "empid=12; name=Paul", + "empid=29; name=Anibal", + "empid=40; name=Emmanuel", + "empid=45; name=Pascal"); + } + + @Test void mergeUnionOrderByEmpid() { + tester(false, + new JdbcTest.HrSchemaBig(), + "select * from (select empid, name from emps where name like 'G%' union select empid, name from emps where name like '%l') order by empid") + .explainContains("EnumerableMergeUnion(all=[false])\n" + + " EnumerableSort(sort0=[$0], dir0=[ASC])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=['G%'], expr#6=[LIKE($t2, $t5)], empid=[$t0], name=[$t2], $condition=[$t6])\n" + + " EnumerableTableScan(table=[[s, emps]])\n" + + " EnumerableSort(sort0=[$0], dir0=[ASC])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=['%l'], expr#6=[LIKE($t2, $t5)], empid=[$t0], name=[$t2], $condition=[$t6])\n" + + " EnumerableTableScan(table=[[s, emps]])\n") + .returnsOrdered( + "empid=1; name=Bill", + "empid=6; name=Guy", + "empid=10; name=Gabriel", + "empid=12; name=Paul", + "empid=29; name=Anibal", + "empid=40; name=Emmanuel", + "empid=45; name=Pascal"); + } + + @Test void mergeUnionAllOrderByName() { + tester(false, + new JdbcTest.HrSchemaBig(), + "select * from (select empid, name from emps where name like 'G%' union all select empid, name from emps where name like '%l') order by name") + .explainContains("EnumerableMergeUnion(all=[true])\n" + + " EnumerableSort(sort0=[$1], dir0=[ASC])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=['G%'], expr#6=[LIKE($t2, $t5)], empid=[$t0], name=[$t2], $condition=[$t6])\n" + + " EnumerableTableScan(table=[[s, emps]])\n" + + " EnumerableSort(sort0=[$1], dir0=[ASC])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=['%l'], expr#6=[LIKE($t2, $t5)], empid=[$t0], name=[$t2], $condition=[$t6])\n" + + " EnumerableTableScan(table=[[s, emps]])\n") + .returnsOrdered( + "empid=29; name=Anibal", + "empid=1; name=Bill", + "empid=40; name=Emmanuel", + "empid=10; name=Gabriel", + "empid=10; name=Gabriel", + "empid=6; name=Guy", + "empid=45; name=Pascal", + "empid=12; name=Paul"); + } + + @Test void mergeUnionOrderByName() { + tester(false, + new JdbcTest.HrSchemaBig(), + "select * from (select empid, name from emps where name like 'G%' union select empid, name from emps where name like '%l') order by name") + .explainContains("EnumerableMergeUnion(all=[false])\n" + + " EnumerableSort(sort0=[$1], dir0=[ASC])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=['G%'], expr#6=[LIKE($t2, $t5)], empid=[$t0], name=[$t2], $condition=[$t6])\n" + + " EnumerableTableScan(table=[[s, emps]])\n" + + " EnumerableSort(sort0=[$1], dir0=[ASC])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=['%l'], expr#6=[LIKE($t2, $t5)], empid=[$t0], name=[$t2], $condition=[$t6])\n" + + " EnumerableTableScan(table=[[s, emps]])\n") + .returnsOrdered( + "empid=29; name=Anibal", + "empid=1; name=Bill", + "empid=40; name=Emmanuel", + "empid=10; name=Gabriel", + "empid=6; name=Guy", + "empid=45; name=Pascal", + "empid=12; name=Paul"); + } + + @Test void mergeUnionSingleColumnOrderByName() { + tester(false, + new JdbcTest.HrSchemaBig(), + "select * from (select name from emps where name like 'G%' union select name from emps where name like '%l') order by name") + .explainContains("EnumerableMergeUnion(all=[false])\n" + + " EnumerableSort(sort0=[$0], dir0=[ASC])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=['G%'], expr#6=[LIKE($t2, $t5)], name=[$t2], $condition=[$t6])\n" + + " EnumerableTableScan(table=[[s, emps]])\n" + + " EnumerableSort(sort0=[$0], dir0=[ASC])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=['%l'], expr#6=[LIKE($t2, $t5)], name=[$t2], $condition=[$t6])\n" + + " EnumerableTableScan(table=[[s, emps]])\n") + .returnsOrdered( + "name=Anibal", + "name=Bill", + "name=Emmanuel", + "name=Gabriel", + "name=Guy", + "name=Pascal", + "name=Paul"); + } + + @Test void mergeUnionOrderByNameWithLimit() { + tester(false, + new JdbcTest.HrSchemaBig(), + "select * from (select empid, name from emps where name like 'G%' union select empid, name from emps where name like '%l') order by name limit 3") + .explainContains("EnumerableLimit(fetch=[3])\n" + + " EnumerableMergeUnion(all=[false])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], empid=[$t0], name=[$t2])\n" + + " EnumerableLimitSort(sort0=[$2], dir0=[ASC], fetch=[3])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=['G%'], expr#6=[LIKE($t2, $t5)], proj#0..4=[{exprs}], $condition=[$t6])\n" + + " EnumerableTableScan(table=[[s, emps]])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], empid=[$t0], name=[$t2])\n" + + " EnumerableLimitSort(sort0=[$2], dir0=[ASC], fetch=[3])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=['%l'], expr#6=[LIKE($t2, $t5)], proj#0..4=[{exprs}], $condition=[$t6])\n" + + " EnumerableTableScan(table=[[s, emps]])\n") + .returnsOrdered( + "empid=29; name=Anibal", + "empid=1; name=Bill", + "empid=40; name=Emmanuel"); + } + + @Test void mergeUnionOrderByNameWithOffset() { + tester(false, + new JdbcTest.HrSchemaBig(), + "select * from (select empid, name from emps where name like 'G%' union select empid, name from emps where name like '%l') order by name offset 2") + .explainContains("EnumerableLimit(offset=[2])\n" + + " EnumerableMergeUnion(all=[false])\n" + + " EnumerableSort(sort0=[$1], dir0=[ASC])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=['G%'], expr#6=[LIKE($t2, $t5)], empid=[$t0], name=[$t2], $condition=[$t6])\n" + + " EnumerableTableScan(table=[[s, emps]])\n" + + " EnumerableSort(sort0=[$1], dir0=[ASC])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=['%l'], expr#6=[LIKE($t2, $t5)], empid=[$t0], name=[$t2], $condition=[$t6])\n" + + " EnumerableTableScan(table=[[s, emps]])\n") + .returnsOrdered( + "empid=40; name=Emmanuel", + "empid=10; name=Gabriel", + "empid=6; name=Guy", + "empid=45; name=Pascal", + "empid=12; name=Paul"); + } + + @Test void mergeUnionOrderByNameWithLimitAndOffset() { + tester(false, + new JdbcTest.HrSchemaBig(), + "select * from (select empid, name from emps where name like 'G%' union select empid, name from emps where name like '%l') order by name limit 3 offset 2") + .explainContains("EnumerableLimit(offset=[2], fetch=[3])\n" + + " EnumerableMergeUnion(all=[false])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], empid=[$t0], name=[$t2])\n" + + " EnumerableLimitSort(sort0=[$2], dir0=[ASC], fetch=[5])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=['G%'], expr#6=[LIKE($t2, $t5)], proj#0..4=[{exprs}], $condition=[$t6])\n" + + " EnumerableTableScan(table=[[s, emps]])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], empid=[$t0], name=[$t2])\n" + + " EnumerableLimitSort(sort0=[$2], dir0=[ASC], fetch=[5])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=['%l'], expr#6=[LIKE($t2, $t5)], proj#0..4=[{exprs}], $condition=[$t6])\n" + + " EnumerableTableScan(table=[[s, emps]])\n") + .returnsOrdered( + "empid=40; name=Emmanuel", + "empid=10; name=Gabriel", + "empid=6; name=Guy"); + } + + @Test void mergeUnionAllOrderByCommissionAscNullsFirstAndNameDesc() { + tester(false, + new JdbcTest.HrSchemaBig(), + "select * from (select commission, name from emps where name like 'R%' union all select commission, name from emps where name like '%y%') order by commission asc nulls first, name desc") + .explainContains("EnumerableMergeUnion(all=[true])\n" + + " EnumerableSort(sort0=[$0], sort1=[$1], dir0=[ASC-nulls-first], dir1=[DESC])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=['R%'], expr#6=[LIKE($t2, $t5)], commission=[$t4], name=[$t2], $condition=[$t6])\n" + + " EnumerableTableScan(table=[[s, emps]])\n" + + " EnumerableSort(sort0=[$0], sort1=[$1], dir0=[ASC-nulls-first], dir1=[DESC])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=['%y%'], expr#6=[LIKE($t2, $t5)], commission=[$t4], name=[$t2], $condition=[$t6])\n" + + " EnumerableTableScan(table=[[s, emps]])\n") + .returnsOrdered( + "commission=null; name=Taylor", + "commission=null; name=Riyad", + "commission=null; name=Riyad", + "commission=null; name=Ralf", + "commission=250; name=Seohyun", + "commission=250; name=Hyuna", + "commission=250; name=Andy", + "commission=500; name=Kylie", + "commission=500; name=Guy"); + } + + @Test void mergeUnionOrderByCommissionAscNullsFirstAndNameDesc() { + tester(false, + new JdbcTest.HrSchemaBig(), + "select * from (select commission, name from emps where name like 'R%' union select commission, name from emps where name like '%y%') order by commission asc nulls first, name desc") + .explainContains("EnumerableMergeUnion(all=[false])\n" + + " EnumerableSort(sort0=[$0], sort1=[$1], dir0=[ASC-nulls-first], dir1=[DESC])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=['R%'], expr#6=[LIKE($t2, $t5)], commission=[$t4], name=[$t2], $condition=[$t6])\n" + + " EnumerableTableScan(table=[[s, emps]])\n" + + " EnumerableSort(sort0=[$0], sort1=[$1], dir0=[ASC-nulls-first], dir1=[DESC])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=['%y%'], expr#6=[LIKE($t2, $t5)], commission=[$t4], name=[$t2], $condition=[$t6])\n" + + " EnumerableTableScan(table=[[s, emps]])\n") + .returnsOrdered( + "commission=null; name=Taylor", + "commission=null; name=Riyad", + "commission=null; name=Ralf", + "commission=250; name=Seohyun", + "commission=250; name=Hyuna", + "commission=250; name=Andy", + "commission=500; name=Kylie", + "commission=500; name=Guy"); + } + + @Test void mergeUnionAllOrderByCommissionAscNullsLastAndNameDesc() { + tester(false, + new JdbcTest.HrSchemaBig(), + "select * from (select commission, name from emps where name like 'R%' union all select commission, name from emps where name like '%y%') order by commission asc nulls last, name desc") + .explainContains("EnumerableMergeUnion(all=[true])\n" + + " EnumerableSort(sort0=[$0], sort1=[$1], dir0=[ASC], dir1=[DESC])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=['R%'], expr#6=[LIKE($t2, $t5)], commission=[$t4], name=[$t2], $condition=[$t6])\n" + + " EnumerableTableScan(table=[[s, emps]])\n" + + " EnumerableSort(sort0=[$0], sort1=[$1], dir0=[ASC], dir1=[DESC])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=['%y%'], expr#6=[LIKE($t2, $t5)], commission=[$t4], name=[$t2], $condition=[$t6])\n" + + " EnumerableTableScan(table=[[s, emps]])\n") + .returnsOrdered( + "commission=250; name=Seohyun", + "commission=250; name=Hyuna", + "commission=250; name=Andy", + "commission=500; name=Kylie", + "commission=500; name=Guy", + "commission=null; name=Taylor", + "commission=null; name=Riyad", + "commission=null; name=Riyad", + "commission=null; name=Ralf"); + } + + @Test void mergeUnionOrderByCommissionAscNullsLastAndNameDesc() { + tester(false, + new JdbcTest.HrSchemaBig(), + "select * from (select commission, name from emps where name like 'R%' union select commission, name from emps where name like '%y%') order by commission asc nulls last, name desc") + .explainContains("EnumerableMergeUnion(all=[false])\n" + + " EnumerableSort(sort0=[$0], sort1=[$1], dir0=[ASC], dir1=[DESC])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=['R%'], expr#6=[LIKE($t2, $t5)], commission=[$t4], name=[$t2], $condition=[$t6])\n" + + " EnumerableTableScan(table=[[s, emps]])\n" + + " EnumerableSort(sort0=[$0], sort1=[$1], dir0=[ASC], dir1=[DESC])\n" + + " EnumerableCalc(expr#0..4=[{inputs}], expr#5=['%y%'], expr#6=[LIKE($t2, $t5)], commission=[$t4], name=[$t2], $condition=[$t6])\n" + + " EnumerableTableScan(table=[[s, emps]])\n") + .returnsOrdered( + "commission=250; name=Seohyun", + "commission=250; name=Hyuna", + "commission=250; name=Andy", + "commission=500; name=Kylie", + "commission=500; name=Guy", + "commission=null; name=Taylor", + "commission=null; name=Riyad", + "commission=null; name=Ralf"); + } + + private CalciteAssert.AssertQuery tester(boolean forceDecorrelate, + Object schema, String sqlQuery) { + return CalciteAssert.that() + .with(CalciteConnectionProperty.LEX, Lex.JAVA) + .with(CalciteConnectionProperty.FORCE_DECORRELATE, forceDecorrelate) + .withSchema("s", new ReflectiveSchema(schema)) + .query(sqlQuery) + .withHook(Hook.PLANNER, (Consumer<RelOptPlanner>) planner -> { + // Force UNION to be implemented via EnumerableMergeUnion + planner.removeRule(EnumerableRules.ENUMERABLE_UNION_RULE); + // Allow EnumerableLimitSort optimization + planner.addRule(EnumerableRules.ENUMERABLE_LIMIT_SORT_RULE); + }); + } +} diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java index 7a7ff63..4f8271d 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java @@ -3674,7 +3674,7 @@ public abstract class EnumerableDefaults { return a0 -> a0.element; } - private static <TSource> Function1<TSource, Wrapped<TSource>> wrapperFor( + static <TSource> Function1<TSource, Wrapped<TSource>> wrapperFor( final EqualityComparer<TSource> comparer) { return a0 -> Wrapped.upAs(comparer, a0); } @@ -3997,7 +3997,7 @@ public abstract class EnumerableDefaults { /** Value wrapped with a comparer. * * @param <T> element type */ - private static class Wrapped<T> { + static class Wrapped<T> { private final EqualityComparer<T> comparer; private final T element; @@ -4683,4 +4683,37 @@ public abstract class EnumerableDefaults { } }; } + + /** + * Merge Union Enumerable. + * Performs a union (or union all) of all its inputs (which must be already sorted), + * respecting the order. + * + * @param sources input enumerables (must be already sorted) + * @param sortKeySelector sort key selector + * @param sortComparator sort comparator to decide the next item + * @param all whether duplicates will be considered or not + * @param equalityComparer {@link EqualityComparer} to control duplicates, + * only used if {@code all} is {@code false} + * @param <TSource> record type + * @param <TKey> sort key + */ + public static <TSource, TKey> Enumerable<TSource> mergeUnion( + List<Enumerable<TSource>> sources, + Function1<TSource, TKey> sortKeySelector, + Comparator<TKey> sortComparator, + boolean all, + EqualityComparer<TSource> equalityComparer) { + return new AbstractEnumerable<TSource>() { + @Override public Enumerator<TSource> enumerator() { + return new MergeUnionEnumerator<>( + sources, + sortKeySelector, + sortComparator, + all, + equalityComparer); + } + }; + } + } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/MergeUnionEnumerator.java b/linq4j/src/main/java/org/apache/calcite/linq4j/MergeUnionEnumerator.java new file mode 100644 index 0000000..841c34f --- /dev/null +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/MergeUnionEnumerator.java @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.linq4j; + +import org.apache.calcite.linq4j.function.EqualityComparer; +import org.apache.calcite.linq4j.function.Function1; + +import org.checkerframework.checker.initialization.qual.UnknownInitialization; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.checker.nullness.qual.RequiresNonNull; + +import java.util.Comparator; +import java.util.HashSet; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.Set; + +/** + * Performs a union (or union all) of all its inputs (which must be already sorted), + * respecting the order. + * @param <TSource> record type + * @param <TKey> sort key + */ +final class MergeUnionEnumerator<TSource, TKey> implements Enumerator<TSource> { + private final Enumerator<TSource>[] inputs; + private final TSource[] currentInputsValues; + private final boolean[] inputsFinished; + private final Function1<TSource, TKey> sortKeySelector; + private final Comparator<TKey> sortComparator; + private TSource currentValue; + private int activeInputs; + + // Set to control duplicates, only used if "all" is false + private final @Nullable Set<EnumerableDefaults.Wrapped<TSource>> processed; + private final @Nullable Function1<TSource, EnumerableDefaults.Wrapped<TSource>> wrapper; + private @Nullable TKey currentKeyInProcessedSet; + + private static final Object NOT_INIT = new Object(); + + MergeUnionEnumerator( + List<Enumerable<TSource>> sources, + Function1<TSource, TKey> sortKeySelector, + Comparator<TKey> sortComparator, + boolean all, + EqualityComparer<TSource> equalityComparer) { + this.sortKeySelector = sortKeySelector; + this.sortComparator = sortComparator; + + if (all) { + this.processed = null; + this.wrapper = null; + } else { + this.processed = new HashSet<>(); + this.wrapper = EnumerableDefaults.wrapperFor(equalityComparer); + } + + final int size = sources.size(); + //noinspection unchecked + this.inputs = new Enumerator[size]; + int i = 0; + for (Enumerable<TSource> source : sources) { + this.inputs[i++] = source.enumerator(); + } + + //noinspection unchecked + this.currentInputsValues = (TSource[]) new Object[size]; + this.activeInputs = this.currentInputsValues.length; + this.inputsFinished = new boolean[size]; + //noinspection unchecked + this.currentValue = (TSource) NOT_INIT; + + initEnumerators(); + } + + @RequiresNonNull("inputs") + @SuppressWarnings("method.invocation.invalid") + private void initEnumerators(@UnknownInitialization MergeUnionEnumerator<TSource, TKey> this) { + for (int i = 0; i < inputs.length; i++) { + moveEnumerator(i); + } + } + + private void moveEnumerator(int i) { + final Enumerator<TSource> enumerator = inputs[i]; + if (!enumerator.moveNext()) { + activeInputs--; + inputsFinished[i] = true; + @Nullable TSource[] auxInputsValues = currentInputsValues; + auxInputsValues[i] = null; + } else { + currentInputsValues[i] = enumerator.current(); + inputsFinished[i] = false; + } + } + + private boolean checkNotDuplicated(TSource value) { + if (processed == null) { + return true; // UNION ALL: no need to check duplicates + } + + // check duplicates + @SuppressWarnings("dereference.of.nullable") + final EnumerableDefaults.Wrapped<TSource> wrapped = wrapper.apply(value); + if (!processed.contains(wrapped)) { + final TKey key = sortKeySelector.apply(value); + if (!processed.isEmpty()) { + // Since inputs are sorted, we do not need to keep in the set all the items that we + // have previously returned, just the ones with the same key, as soon as we see a new + // key, we can clear the set containing the items belonging to the previous key + @SuppressWarnings("argument.type.incompatible") + final int sortComparison = sortComparator.compare(key, currentKeyInProcessedSet); + if (sortComparison != 0) { + processed.clear(); + currentKeyInProcessedSet = key; + } + } else { + currentKeyInProcessedSet = key; + } + processed.add(wrapped); + return true; + } + return false; + } + + private int compare(TSource e1, TSource e2) { + final TKey key1 = sortKeySelector.apply(e1); + final TKey key2 = sortKeySelector.apply(e2); + return sortComparator.compare(key1, key2); + } + + @Override public TSource current() { + if (currentValue == NOT_INIT) { + throw new NoSuchElementException(); + } + return currentValue; + } + + @Override public boolean moveNext() { + while (activeInputs > 0) { + int candidateIndex = -1; + for (int i = 0; i < currentInputsValues.length; i++) { + if (!inputsFinished[i]) { + candidateIndex = i; + break; + } + } + + if (activeInputs > 1) { + for (int i = candidateIndex + 1; i < currentInputsValues.length; i++) { + if (inputsFinished[i]) { + continue; + } + + final int comp = compare( + currentInputsValues[candidateIndex], + currentInputsValues[i]); + if (comp > 0) { + candidateIndex = i; + } + } + } + + if (checkNotDuplicated(currentInputsValues[candidateIndex])) { + currentValue = currentInputsValues[candidateIndex]; + moveEnumerator(candidateIndex); + return true; + } else { + moveEnumerator(candidateIndex); + // continue loop + } + } + return false; + } + + @Override public void reset() { + for (Enumerator<TSource> enumerator : inputs) { + enumerator.reset(); + } + if (processed != null) { + processed.clear(); + currentKeyInProcessedSet = null; + } + //noinspection unchecked + currentValue = (TSource) NOT_INIT; + activeInputs = currentInputsValues.length; + initEnumerators(); + } + + @Override public void close() { + for (Enumerator<TSource> enumerator : inputs) { + enumerator.close(); + } + } +}