This is an automated email from the ASF dual-hosted git repository. jhyde pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/calcite.git
commit 0cce229903a845a7b8ed36cf86d6078fd82d73d3 Author: Julian Hyde <[email protected]> AuthorDate: Mon Jun 24 13:01:37 2019 -0700 [CALCITE-3145] RelBuilder.aggregate throws IndexOutOfBoundsException if groupKey is non-empty and there are duplicate aggregate functions The cause is that [CALCITE-3123] did not handle the case of non-empty groupKey. Enable RelBuilder.Config.dedupAggregateCalls by default. --- .../java/org/apache/calcite/tools/RelBuilder.java | 53 +++++++++++----------- .../org/apache/calcite/test/RelBuilderTest.java | 47 +++++++++++++++---- 2 files changed, 64 insertions(+), 36 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java index 079c3de..f19a510 100644 --- a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java +++ b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java @@ -1634,32 +1634,33 @@ public class RelBuilder { if (!config.dedupAggregateCalls || Util.isDistinct(aggregateCalls)) { return aggregate_(groupSet, groupSets, r, aggregateCalls, registrar.extraNodes, frame.fields); - } else { - // There are duplicate aggregate calls. - final Set<AggregateCall> callSet = new HashSet<>(); - final List<Integer> projects = - new ArrayList<>(Util.range(groupSet.cardinality())); - final List<AggregateCall> distinctAggregateCalls = new ArrayList<>(); - for (AggregateCall aggregateCall : aggregateCalls) { - final int i; - if (callSet.add(aggregateCall)) { - i = distinctAggregateCalls.size(); - distinctAggregateCalls.add(aggregateCall); - } else { - i = distinctAggregateCalls.indexOf(aggregateCall); - assert i >= 0; - } - projects.add(i); - } - aggregate_(groupSet, groupSets, r, distinctAggregateCalls, - registrar.extraNodes, frame.fields); - final List<RexNode> fields = - new ArrayList<>(fields(Util.range(groupSet.cardinality()))); - for (Ord<Integer> project : Ord.zip(projects)) { - fields.add(alias(field(project.e), aggregateCalls.get(project.i).name)); + } + + // There are duplicate aggregate calls. Rebuild the list to eliminate + // duplicates, then add a Project. + final Set<AggregateCall> callSet = new HashSet<>(); + final List<Pair<Integer, String>> projects = new ArrayList<>(); + Util.range(groupSet.cardinality()) + .forEach(i -> projects.add(Pair.of(i, null))); + final List<AggregateCall> distinctAggregateCalls = new ArrayList<>(); + for (AggregateCall aggregateCall : aggregateCalls) { + final int i; + if (callSet.add(aggregateCall)) { + i = distinctAggregateCalls.size(); + distinctAggregateCalls.add(aggregateCall); + } else { + i = distinctAggregateCalls.indexOf(aggregateCall); + assert i >= 0; } - return project(fields); + projects.add(Pair.of(groupSet.cardinality() + i, aggregateCall.name)); } + aggregate_(groupSet, groupSets, r, distinctAggregateCalls, + registrar.extraNodes, frame.fields); + final List<RexNode> fields = projects.stream() + .map(p -> p.right == null ? field(p.left) + : alias(field(p.left), p.right)) + .collect(Collectors.toList()); + return project(fields); } /** Finishes the implementation of {@link #aggregate} by creating an @@ -2787,10 +2788,10 @@ public class RelBuilder { public static class Config { /** Default configuration. */ public static final Config DEFAULT = - new Config(false, true); + new Config(true, true); /** Whether {@link RelBuilder#aggregate} should eliminate duplicate - * aggregate calls; default true but currently disabled. */ + * aggregate calls; default true. */ public final boolean dedupAggregateCalls; /** Whether to simplify expressions; default true. */ diff --git a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java index 1178a30..1f8f844 100644 --- a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java +++ b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java @@ -967,14 +967,6 @@ public class RelBuilderTest { /** Tests that {@link RelBuilder#aggregate} eliminates duplicate aggregate * calls and creates a {@code Project} to compensate. */ @Test public void testAggregateEliminatesDuplicateCalls() { - final Function<RelBuilder, RelNode> fn = builder -> - builder.scan("EMP") - .aggregate(builder.groupKey(), - builder.sum(builder.field(1)).as("S1"), - builder.count().as("C"), - builder.sum(builder.field(2)).as("S2"), - builder.sum(builder.field(1)).as("S1b")) - .build(); final RelBuilder builder = createBuilder(configBuilder -> configBuilder.withDedupAggregateCalls(true)); @@ -982,7 +974,7 @@ public class RelBuilderTest { + "LogicalProject(S1=[$0], C=[$1], S2=[$2], S1b=[$0])\n" + " LogicalAggregate(group=[{}], S1=[SUM($1)], C=[COUNT()], S2=[SUM($2)])\n" + " LogicalTableScan(table=[[scott, EMP]])\n"; - assertThat(fn.apply(builder), hasTree(expected)); + assertThat(buildRelWithDuplicateAggregates(builder), hasTree(expected)); // Now, disable the rewrite final RelBuilder builder2 = @@ -991,7 +983,42 @@ public class RelBuilderTest { final String expected2 = "" + "LogicalAggregate(group=[{}], S1=[SUM($1)], C=[COUNT()], S2=[SUM($2)], S1b=[SUM($1)])\n" + " LogicalTableScan(table=[[scott, EMP]])\n"; - assertThat(fn.apply(builder2), hasTree(expected2)); + assertThat(buildRelWithDuplicateAggregates(builder2), hasTree(expected2)); + } + + /** As {@link #testAggregateEliminatesDuplicateCalls()} but with a + * single-column GROUP BY clause. */ + @Test public void testAggregateEliminatesDuplicateCalls2() { + final RelBuilder builder = RelBuilder.create(config().build()); + RelNode root = buildRelWithDuplicateAggregates(builder, 0); + final String expected = "" + + "LogicalProject(EMPNO=[$0], S1=[$1], C=[$2], S2=[$3], S1b=[$1])\n" + + " LogicalAggregate(group=[{0}], S1=[SUM($1)], C=[COUNT()], S2=[SUM($2)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + assertThat(root, hasTree(expected)); + } + + /** As {@link #testAggregateEliminatesDuplicateCalls()} but with a + * multi-column GROUP BY clause. */ + @Test public void testAggregateEliminatesDuplicateCalls3() { + final RelBuilder builder = RelBuilder.create(config().build()); + RelNode root = buildRelWithDuplicateAggregates(builder, 2, 0, 4, 3); + final String expected = "" + + "LogicalProject(EMPNO=[$0], JOB=[$1], MGR=[$2], HIREDATE=[$3], S1=[$4], C=[$5], S2=[$6], S1b=[$4])\n" + + " LogicalAggregate(group=[{0, 2, 3, 4}], S1=[SUM($1)], C=[COUNT()], S2=[SUM($2)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + assertThat(root, hasTree(expected)); + } + + private RelNode buildRelWithDuplicateAggregates(RelBuilder builder, + int... groupFieldOrdinals) { + return builder.scan("EMP") + .aggregate(builder.groupKey(groupFieldOrdinals), + builder.sum(builder.field(1)).as("S1"), + builder.count().as("C"), + builder.sum(builder.field(2)).as("S2"), + builder.sum(builder.field(1)).as("S1b")) + .build(); } @Test public void testAggregateFilter() {
