This is an automated email from the ASF dual-hosted git repository. jhyde pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/calcite.git
commit e01ba5ab6e7c57348f9f7be2babf00ae007204b5 Author: Julian Hyde <[email protected]> AuthorDate: Fri Jun 7 15:56:13 2019 -0700 [CALCITE-3123] In RelBuilder, eliminate duplicate aggregate calls --- .../java/org/apache/calcite/tools/RelBuilder.java | 44 ++++++++++++++++++++-- .../org/apache/calcite/test/RelBuilderTest.java | 19 ++++++++++ 2 files changed, 60 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java index c139f58..cdf71c4 100644 --- a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java +++ b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java @@ -1602,7 +1602,45 @@ public class RelBuilder { for (ImmutableBitSet set : groupSets) { assert groupSet.contains(set); } - RelNode aggregate = aggregateFactory.createAggregate(r, + + if (Util.isDistinct(aggregateCalls)) { + return aggregate_(groupSet, groupSets, r, aggregateCalls, + registrar.extraNodes, frame.fields); + } else { + // There are duplicate aggregate calls. + final Set<AggregateCall> callSet = new HashSet<>(); + final List<Integer> projects = + new ArrayList<>(Util.range(groupSet.cardinality())); + final List<AggregateCall> distinctAggregateCalls = new ArrayList<>(); + for (AggregateCall aggregateCall : aggregateCalls) { + final int i; + if (callSet.add(aggregateCall)) { + i = distinctAggregateCalls.size(); + distinctAggregateCalls.add(aggregateCall); + } else { + i = distinctAggregateCalls.indexOf(aggregateCall); + assert i >= 0; + } + projects.add(i); + } + aggregate_(groupSet, groupSets, r, distinctAggregateCalls, + registrar.extraNodes, frame.fields); + final List<RexNode> fields = + new ArrayList<>(fields(Util.range(groupSet.cardinality()))); + for (Ord<Integer> project : Ord.zip(projects)) { + fields.add(alias(field(project.e), aggregateCalls.get(project.i).name)); + } + return project(fields); + } + } + + /** Finishes the implementation of {@link #aggregate} by creating an + * {@link Aggregate} and pushing it onto the stack. */ + private RelBuilder aggregate_(ImmutableBitSet groupSet, + ImmutableList<ImmutableBitSet> groupSets, RelNode input, + List<AggregateCall> aggregateCalls, List<RexNode> extraNodes, + ImmutableList<Field> inFields) { + final RelNode aggregate = aggregateFactory.createAggregate(input, groupSet, groupSets, aggregateCalls); // build field list @@ -1612,11 +1650,11 @@ public class RelBuilder { int i = 0; // first, group fields for (Integer groupField : groupSet.asList()) { - RexNode node = registrar.extraNodes.get(groupField); + RexNode node = extraNodes.get(groupField); final SqlKind kind = node.getKind(); switch (kind) { case INPUT_REF: - fields.add(frame.fields.get(((RexInputRef) node).getIndex())); + fields.add(inFields.get(((RexInputRef) node).getIndex())); break; default: String name = aggregateFields.get(i).getName(); diff --git a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java index a4f28af..c06bdeb 100644 --- a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java +++ b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java @@ -927,6 +927,25 @@ public class RelBuilderTest { assertThat(root, hasTree(expected)); } + /** Tests that {@link RelBuilder#aggregate} eliminates duplicate aggregate + * calls and creates a {@code Project} to compensate. */ + @Test public void testAggregateEliminatesDuplicateCalls() { + final RelBuilder builder = RelBuilder.create(config().build()); + RelNode root = + builder.scan("EMP") + .aggregate(builder.groupKey(), + builder.sum(builder.field(1)).as("S1"), + builder.count().as("C"), + builder.sum(builder.field(2)).as("S2"), + builder.sum(builder.field(1)).as("S1b")) + .build(); + final String expected = "" + + "LogicalProject(S1=[$0], C=[$1], S2=[$2], S1b=[$0])\n" + + " LogicalAggregate(group=[{}], S1=[SUM($1)], C=[COUNT()], S2=[SUM($2)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + assertThat(root, hasTree(expected)); + } + @Test public void testAggregateFilter() { // Equivalent SQL: // SELECT deptno, COUNT(*) FILTER (WHERE empno > 100) AS c
