This is an automated email from the ASF dual-hosted git repository. jhyde pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/calcite.git
commit 467e509f8a5348ac83534ec46b873b6645524990 Author: Julian Hyde <[email protected]> AuthorDate: Fri Feb 11 18:12:27 2022 -0800 [CALCITE-5802] In RelBuilder, add method aggregateRex, to allow aggregating complex expressions such as "1 + SUM(x + 2)" --- .../java/org/apache/calcite/tools/RelBuilder.java | 68 +++++++++++ .../org/apache/calcite/test/RelBuilderTest.java | 126 +++++++++++++++++++-- 2 files changed, 185 insertions(+), 9 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java index f0ae2f37cd..a99245a805 100644 --- a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java +++ b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java @@ -2611,6 +2611,29 @@ public class RelBuilder { && groupKey.isSimple(); } + /** Creates an {@link Aggregate} from {@link RexNode} expressions that may mix + * aggregate calls, group keys and other operators, e.g. {@code 1 + SUM(x + 2)}. */ + public RelBuilder aggregateRex(GroupKey groupKey, + RexNode... nodes) { + return aggregateRex(groupKey, false, ImmutableList.copyOf(nodes)); + } + + /** As {@link #aggregateRex(GroupKey, RexNode...)}, but if {@code projectKey} + * is true, the output starts with the {@code groupKey} columns. */ + public RelBuilder aggregateRex(GroupKey groupKey, boolean projectKey, + Iterable<? extends RexNode> nodes) { + final GroupKeyImpl groupKeyImpl = (GroupKeyImpl) groupKey; + final AggBuilder aggBuilder = new AggBuilder(groupKeyImpl.nodes); + for (RexNode node : nodes) { + aggBuilder.add(node); + } + return aggregate(groupKey, aggBuilder.aggCalls) + .project( + Iterables.concat( + fields(Util.range(projectKey ? groupKey.groupKeyCount() : 0)), + aggBuilder.postProjects)); + } + /** Finishes the implementation of {@link #aggregate} by creating an * {@link Aggregate} and pushing it onto the stack.
*/ private RelBuilder aggregate_(ImmutableBitSet groupSet, @@ -4972,4 +4995,49 @@ public class RelBuilder { Config withRemoveRedundantDistinct(boolean removeRedundantDistinct); } + /** Working state for {@link #aggregateRex}. */ + private class AggBuilder { + final ImmutableList<RexNode> groupKeys; + final List<RexNode> postProjects = new ArrayList<>(); + final List<AggCall> aggCalls = new ArrayList<>(); + + private AggBuilder(ImmutableList<RexNode> groupKeys) { + this.groupKeys = groupKeys; + } + + /** Adds a node that may or may not contain an aggregate function. */ + void add(RexNode node) { + postProjects.add(convert(node)); + } + + /** Converts a node that may or may not contain an aggregate function, and + * returns an expression over the aggregate layer's output ({@link #groupKeys}, + * then {@link #aggCalls}); each aggregate call found is appended to {@link #aggCalls}. */ + private RexNode convert(RexNode node) { + final RexBuilder rexBuilder = cluster.getRexBuilder(); + if (node instanceof RexCall) { + final RexCall call = (RexCall) node; + if (call.getOperator().isAggregator()) { + final AggCall aggCall = + aggregateCall((SqlAggFunction) call.op, call.operands); + final int i = groupKeys.size() + aggCalls.size(); + aggCalls.add(aggCall); + return rexBuilder.makeInputRef(call.getType(), i); + } else { + final List<RexNode> operands = new ArrayList<>(); + call.operands.forEach(operand -> + operands.add(convert(operand))); + return call.clone(call.type, operands); + } + } else if (node instanceof RexInputRef) { + final int j = groupKeys.indexOf(node); + if (j < 0) { + throw new IllegalArgumentException("not a group key: " + node); + } + return rexBuilder.makeInputRef(node.getType(), j); + } else { + return node; + } + } + } } diff --git a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java index e71b816bc8..a746adfba1 100644 --- a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java +++
b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java @@ -121,6 +121,7 @@ import java.util.function.Function; import java.util.function.UnaryOperator; import java.util.stream.Collectors; +import static org.apache.calcite.test.Matchers.hasFieldNames; import static org.apache.calcite.test.Matchers.hasHints; import static org.apache.calcite.test.Matchers.hasTree; @@ -1158,7 +1159,7 @@ public class RelBuilderTest { .rename(ImmutableList.of("x", "y z")) .build(); assertThat(root, hasTree(expected)); - assertThat(root.getRowType().getFieldNames(), hasToString("[x, y z]")); + assertThat(root, hasFieldNames("[x, y z]")); } /** Tests conditional rename using {@link RelBuilder#let}. */ @@ -2166,7 +2167,7 @@ public class RelBuilderTest { * GROUP_ID()</a>. */ @Test void testAggregateGroupingSetsGroupId() { final String plan = "" - + "LogicalProject(JOB=[$0], DEPTNO=[$1], $f2=[0:BIGINT])\n" + + "LogicalProject(JOB=[$0], DEPTNO=[$1], g=[0:BIGINT])\n" + " LogicalAggregate(group=[{2, 7}], groups=[[{2, 7}, {2}, {7}]])\n" + " LogicalTableScan(table=[[scott, EMP]])\n"; assertThat(groupIdRel(createBuilder(), false), hasTree(plan)); @@ -2177,10 +2178,10 @@ public class RelBuilderTest { // If any group occurs more than once, we need a UNION ALL. final String plan2 = "" + "LogicalUnion(all=[true])\n" - + " LogicalProject(JOB=[$0], DEPTNO=[$1], $f2=[0:BIGINT])\n" + + " LogicalProject(JOB=[$0], DEPTNO=[$1], g=[0:BIGINT])\n" + " LogicalAggregate(group=[{2, 7}], groups=[[{2, 7}, {2}, {7}]])\n" + " LogicalTableScan(table=[[scott, EMP]])\n" - + " LogicalProject(JOB=[$0], DEPTNO=[$1], $f2=[1:BIGINT])\n" + + " LogicalProject(JOB=[$0], DEPTNO=[$1], g=[1:BIGINT])\n" + " LogicalAggregate(group=[{2, 7}])\n" + " LogicalTableScan(table=[[scott, EMP]])\n"; assertThat(groupIdRel(createBuilder(), true), hasTree(plan2)); @@ -2200,7 +2201,7 @@ public class RelBuilderTest { .addAll(extra ?
ImmutableList.of(builder.fields(djList)) : ImmutableList.of()) .build()), - builder.aggregateCall(SqlStdOperatorTable.GROUP_ID)) + builder.aggregateCall(SqlStdOperatorTable.GROUP_ID).as("g")) .build(); } @@ -3279,6 +3280,112 @@ public class RelBuilderTest { assertThat(root, hasTree(expected)); } + /** Test case for + * <a href="https://issues.apache.org/jira/browse/CALCITE-5802">[CALCITE-5802] + * In RelBuilder, add method aggregateRex, to allow aggregating complex + * expressions such as "1 + SUM(x + 2)"</a>. */ + @Test void testAggregateRex() { + // SELECT deptno, + // deptno + 2 AS d2, + // 3 + SUM(4 + sal) AS s + // FROM emp + // GROUP BY deptno + Function<RelBuilder, RelNode> f = b -> + b.scan("EMP") + .aggregateRex(b.groupKey(b.field("DEPTNO")), + b.field("DEPTNO"), + b.alias( + b.call(SqlStdOperatorTable.PLUS, b.field("DEPTNO"), + b.literal(2)), + "d2"), + b.alias( + b.call(SqlStdOperatorTable.PLUS, b.literal(3), + b.call(SqlStdOperatorTable.SUM, + b.call(SqlStdOperatorTable.PLUS, b.literal(4), + b.field("SAL")))), + "s")) + .build(); + final String expected = "" + + "LogicalProject(DEPTNO=[$0], d2=[+($0, 2)], s=[+(3, $1)])\n" + + " LogicalAggregate(group=[{0}], agg#0=[SUM($1)])\n" + + " LogicalProject(DEPTNO=[$7], $f8=[+(4, $5)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + final String expectedRowType = + "RecordType(TINYINT DEPTNO, INTEGER d2, DECIMAL(19, 2) s) NOT NULL"; + final RelNode r = f.apply(createBuilder()); + assertThat(r, hasTree(expected)); + assertThat(r.getRowType().getFullTypeString(), is(expectedRowType)); + } + + /** Tests {@link RelBuilder#aggregateRex} with an expression that contains + * no aggregate function or field reference; it is evaluated post aggregation.
*/ @Test void testAggregateRex2() { + // SELECT CURRENT_DATE AS d + // FROM emp + // GROUP BY () + BiFunction<RelBuilder, Boolean, RelNode> f = (b, projectKey) -> + b.scan("EMP") + .aggregateRex(b.groupKey(), projectKey, + ImmutableList.of( + b.alias(b.call(SqlStdOperatorTable.CURRENT_DATE), "d"))) + .build(); + final String expected = "" + + "LogicalProject(d=[CURRENT_DATE])\n" + + " LogicalValues(tuples=[[{ true }]])\n"; + final String expectedRowType = "RecordType(DATE NOT NULL d) NOT NULL"; + final RelNode r = f.apply(createBuilder(), false); + assertThat(r, hasTree(expected)); + assertThat(r.getRowType().getFullTypeString(), is(expectedRowType)); + + // As above, with projectKey = true + final RelNode r2 = f.apply(createBuilder(), true); + assertThat(r2, hasTree(expected)); + assertThat(r2.getRowType().getFullTypeString(), is(expectedRowType)); + + // As above, disabling extra fields + final String expected3 = "" + + "LogicalProject(d=[CURRENT_DATE])\n" + + " LogicalValues(tuples=[[{ }]])\n"; + final RelNode r3 = + f.apply(createBuilder(c -> c.withPreventEmptyFieldList(false)), + false); + assertThat(r3, hasTree(expected3)); + assertThat(r3.getRowType().getFullTypeString(), is(expectedRowType)); + } + + /** Tests {@link RelBuilder#aggregateRex} with literal expressions; they are + * evaluated post aggregation, and the result simplifies to a single {@code Values}.
*/ @Test void testAggregateRex3() { + // SELECT 2 AS two, false AS f + // FROM emp + // GROUP BY () + BiFunction<RelBuilder, Boolean, RelNode> f = (b, projectKey) -> + b.scan("EMP") + .aggregateRex(b.groupKey(), projectKey, + ImmutableList.of(b.alias(b.literal(2), "two"), + b.alias(b.literal(false), "f"))) + .build(); + final String expected = + "LogicalValues(tuples=[[{ 2, false }]])\n"; + final String expectedRowType = + "RecordType(INTEGER NOT NULL two, BOOLEAN NOT NULL f) NOT NULL"; + final RelNode r = f.apply(createBuilder(), false); + assertThat(r, hasTree(expected)); + assertThat(r.getRowType().getFullTypeString(), is(expectedRowType)); + + // As above, with projectKey = true + final RelNode r2 = f.apply(createBuilder(), true); + assertThat(r2, hasTree(expected)); + assertThat(r2.getRowType().getFullTypeString(), is(expectedRowType)); + + // As above, disabling extra fields + final RelNode r3 = + f.apply(createBuilder(c -> c.withPreventEmptyFieldList(false)), + false); + assertThat(r3, hasTree(expected)); + assertThat(r3.getRowType().getFullTypeString(), is(expectedRowType)); + } + /** Tests that a projection retains field names after a join. */ @Test void testProjectJoin() { final RelBuilder builder = RelBuilder.create(config().build()); @@ -3765,10 +3872,11 @@ public class RelBuilderTest { .build(); final String expected = "LogicalValues(tuples=[[{ 1, true }, { 2, false }]])\n"; - final String expectedRowType = "RecordType(INTEGER x, BOOLEAN y)"; - assertThat(f.apply(createBuilder()), hasTree(expected)); - assertThat(f.apply(createBuilder()).getRowType(), - hasToString(expectedRowType)); + final String expectedRowType = + "RecordType(INTEGER NOT NULL x, BOOLEAN NOT NULL y) NOT NULL"; + final RelNode r = f.apply(createBuilder()); + assertThat(r, hasTree(expected)); + assertThat(r.getRowType().getFullTypeString(), is(expectedRowType)); } /** Tests that {@code Union(Project(Values), ... Project(Values))} is
