[CALCITE-1980] RelBuilder.aggregate should rename underlying fields if groupKey contains alias
Test case by Pavel Gubin, in the following PR; did not use the rest of the PR. Close apache/calcite#535 Project: http://git-wip-us.apache.org/repos/asf/calcite/repo Commit: http://git-wip-us.apache.org/repos/asf/calcite/commit/2773c484 Tree: http://git-wip-us.apache.org/repos/asf/calcite/tree/2773c484 Diff: http://git-wip-us.apache.org/repos/asf/calcite/diff/2773c484 Branch: refs/heads/master Commit: 2773c4846a67360de8301680e375779ce3b1304b Parents: 43fa8e9 Author: Julian Hyde <[email protected]> Authored: Wed Sep 13 12:00:35 2017 -0700 Committer: Julian Hyde <[email protected]> Committed: Mon Oct 2 11:13:42 2017 -0700 ---------------------------------------------------------------------- .../org/apache/calcite/tools/RelBuilder.java | 92 ++++++++++++-------- .../org/apache/calcite/test/RelBuilderTest.java | 44 ++++++++++ 2 files changed, 102 insertions(+), 34 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/calcite/blob/2773c484/core/src/main/java/org/apache/calcite/tools/RelBuilder.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java index fe822f0..0a726c7 100644 --- a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java +++ b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java @@ -1049,10 +1049,12 @@ public class RelBuilder { /** Creates an {@link org.apache.calcite.rel.core.Aggregate} with a list of * calls. */ public RelBuilder aggregate(GroupKey groupKey, Iterable<AggCall> aggCalls) { - final List<RexNode> extraNodes = new ArrayList<>(fields()); + final Registrar registrar = new Registrar(); + registrar.extraNodes.addAll(fields()); + registrar.names.addAll(peek().getRowType().getFieldNames()); final GroupKeyImpl groupKey_ = (GroupKeyImpl) groupKey; final ImmutableBitSet groupSet = - ImmutableBitSet.of(registerExpressions(extraNodes, groupKey_.nodes)); + ImmutableBitSet.of(registrar.registerExpressions(groupKey_.nodes)); label: if (Iterables.isEmpty(aggCalls) && !groupKey_.indicator) { final RelMetadataQuery mq = peek().getCluster().getMetadataQuery(); @@ -1064,10 +1066,12 @@ public class RelBuilder { break label; } } - final Boolean unique = mq.areColumnsUnique(peek(), groupSet); - if (unique != null && unique) { - // Rel is already unique. Nothing to do. - return this; + if (registrar.extraNodes.size() == fields().size()) { + final Boolean unique = mq.areColumnsUnique(peek(), groupSet); + if (unique != null && unique) { + // Rel is already unique. Nothing to do. + return this; + } } final Double maxRowCount = mq.getMaxRowCount(peek()); if (maxRowCount != null && maxRowCount <= 1D) { @@ -1077,12 +1081,12 @@ public class RelBuilder { } final ImmutableList<ImmutableBitSet> groupSets; if (groupKey_.nodeLists != null) { - final int sizeBefore = extraNodes.size(); + final int sizeBefore = registrar.extraNodes.size(); final SortedSet<ImmutableBitSet> groupSetSet = new TreeSet<>(ImmutableBitSet.ORDERING); for (ImmutableList<RexNode> nodeList : groupKey_.nodeLists) { final ImmutableBitSet groupSet2 = - ImmutableBitSet.of(registerExpressions(extraNodes, nodeList)); + ImmutableBitSet.of(registrar.registerExpressions(nodeList)); if (!groupSet.contains(groupSet2)) { throw new IllegalArgumentException("group set element " + nodeList + " must be a subset of group key"); @@ -1090,10 +1094,11 @@ public class RelBuilder { groupSetSet.add(groupSet2); } groupSets = ImmutableList.copyOf(groupSetSet); - if (extraNodes.size() > sizeBefore) { + if (registrar.extraNodes.size() > sizeBefore) { throw new IllegalArgumentException( "group sets contained expressions not in group key: " - + extraNodes.subList(sizeBefore, extraNodes.size())); + + registrar.extraNodes.subList(sizeBefore, + registrar.extraNodes.size())); } } else { groupSets = ImmutableList.of(groupSet); @@ -1101,13 +1106,14 @@ public class RelBuilder { for (AggCall aggCall : aggCalls) { if (aggCall instanceof AggCallImpl) { final AggCallImpl aggCall1 = (AggCallImpl) aggCall; - registerExpressions(extraNodes, aggCall1.operands); + registrar.registerExpressions(aggCall1.operands); if (aggCall1.filter != null) { - registerExpression(extraNodes, aggCall1.filter); + registrar.registerExpression(aggCall1.filter); } } } - project(extraNodes); + project(registrar.extraNodes); + rename(registrar.names); final Frame frame = stack.pop(); final RelNode r = frame.rel; final List<AggregateCall> aggregateCalls = new ArrayList<>(); @@ -1115,9 +1121,10 @@ public class RelBuilder { final AggregateCall aggregateCall; if (aggCall instanceof AggCallImpl) { final AggCallImpl aggCall1 = (AggCallImpl) aggCall; - final List<Integer> args = registerExpressions(extraNodes, aggCall1.operands); + final List<Integer> args = + registrar.registerExpressions(aggCall1.operands); final int filterArg = aggCall1.filter == null ? -1 - : registerExpression(extraNodes, aggCall1.filter); + : registrar.registerExpression(aggCall1.filter); if (aggCall1.distinct && !aggCall1.aggFunction.isQuantifierAllowed()) { throw new IllegalArgumentException("DISTINCT not allowed"); } @@ -1147,7 +1154,7 @@ public class RelBuilder { int i = 0; // first, group fields for (Integer groupField : groupSet.asList()) { - RexNode node = extraNodes.get(groupField); + RexNode node = registrar.extraNodes.get(groupField); final SqlKind kind = node.getKind(); switch (kind) { case INPUT_REF: @@ -1184,24 +1191,6 @@ public class RelBuilder { return this; } - private static int registerExpression(List<RexNode> exprList, RexNode node) { - int i = exprList.indexOf(node); - if (i < 0) { - i = exprList.size(); - exprList.add(node); - } - return i; - } - - private static List<Integer> registerExpressions(List<RexNode> extraNodes, - Iterable<? extends RexNode> nodes) { - final List<Integer> builder = new ArrayList<>(); - for (RexNode node : nodes) { - builder.add(registerExpression(extraNodes, node)); - } - return builder; - } - private RelBuilder setOp(boolean all, SqlKind kind, int n) { List<RelNode> inputs = new LinkedList<>(); for (int i = 0; i < n; i++) { @@ -1800,6 +1789,41 @@ public class RelBuilder { } } + /** Collects the extra expressions needed for {@link #aggregate}. + * + * <p>The extra expressions come from the group key and as arguments to + * aggregate calls, and later there will be a {@link #project} or a + * {@link #rename(List)} if necessary. */ + private static class Registrar { + final List<RexNode> extraNodes = new ArrayList<>(); + final List<String> names = new ArrayList<>(); + + int registerExpression(RexNode node) { + switch (node.getKind()) { + case AS: + final List<RexNode> operands = ((RexCall) node).operands; + int i = registerExpression(operands.get(0)); + names.set(i, RexLiteral.stringValue(operands.get(1))); + return i; + } + int i = extraNodes.indexOf(node); + if (i < 0) { + i = extraNodes.size(); + extraNodes.add(node); + names.add(null); + } + return i; + } + + List<Integer> registerExpressions(Iterable<? extends RexNode> nodes) { + final List<Integer> builder = new ArrayList<>(); + for (RexNode node : nodes) { + builder.add(registerExpression(node)); + } + return builder; + } + } + /** Builder stack frame. * * <p>Describes a previously created relational expression and http://git-wip-us.apache.org/repos/asf/calcite/blob/2773c484/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java index 8a337cc..0f66db3 100644 --- a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java +++ b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java @@ -702,6 +702,50 @@ public class RelBuilderTest { assertThat(str(root), is(expected)); } + /** Test case for + * <a href="https://issues.apache.org/jira/browse/CALCITE-1980">[CALCITE-1980] + * RelBuilder gives NPE if groupKey contains alias</a>. + * + * <p>Now, the alias does not cause a new expression to be added to the input, + * but causes the referenced fields to be renamed. */ + @Test public void testAggregateProjectWithAliases() { + final RelBuilder builder = RelBuilder.create(config().build()); + RelNode root = + builder.scan("EMP") + .project(builder.field("DEPTNO")) + .aggregate( + builder.groupKey( + builder.alias(builder.field("DEPTNO"), "departmentNo"))) + .build(); + final String expected = "" + + "LogicalAggregate(group=[{0}])\n" + + " LogicalProject(departmentNo=[$0])\n" + + " LogicalProject(DEPTNO=[$7])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + assertThat(str(root), is(expected)); + } + + @Test public void testAggregateProjectWithExpression() { + final RelBuilder builder = RelBuilder.create(config().build()); + RelNode root = + builder.scan("EMP") + .project(builder.field("DEPTNO")) + .aggregate( + builder.groupKey( + builder.alias( + builder.call(SqlStdOperatorTable.PLUS, + builder.field("DEPTNO"), builder.literal(3)), + "d3"))) + .build(); + final String expected = "" + + "LogicalAggregate(group=[{1}])\n" + + " LogicalProject(DEPTNO=[$0], d3=[$1])\n" + + " LogicalProject(DEPTNO=[$0], $f1=[+($0, 3)])\n" + + " LogicalProject(DEPTNO=[$7])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + assertThat(str(root), is(expected)); + } + @Test public void testAggregateGroupingKeyOutOfRangeFails() { final RelBuilder builder = RelBuilder.create(config().build()); try {
