This is an automated email from the ASF dual-hosted git repository.

jhyde pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git

commit 467e509f8a5348ac83534ec46b873b6645524990
Author: Julian Hyde <[email protected]>
AuthorDate: Fri Feb 11 18:12:27 2022 -0800

    [CALCITE-5802] In RelBuilder, add method aggregateRex, to allow aggregating complex expressions such as "1 + SUM(x + 2)"
---
 .../java/org/apache/calcite/tools/RelBuilder.java  |  68 +++++++++++
 .../org/apache/calcite/test/RelBuilderTest.java    | 126 +++++++++++++++++++--
 2 files changed, 185 insertions(+), 9 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
index f0ae2f37cd..a99245a805 100644
--- a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
+++ b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
@@ -2611,6 +2611,29 @@ public class RelBuilder {
         && groupKey.isSimple();
   }
 
+  /** Creates an {@link Aggregate} with a set of hybrid expressions represented
+   * as {@link RexNode}. */
+  public RelBuilder aggregateRex(GroupKey groupKey,
+      RexNode... nodes) {
+    return aggregateRex(groupKey, false, ImmutableList.copyOf(nodes));
+  }
+
+  /** Creates an {@link Aggregate} with a set of hybrid expressions represented
+   * as {@link RexNode}, optionally projecting the {@code groupKey} columns. */
+  public RelBuilder aggregateRex(GroupKey groupKey, boolean projectKey,
+      Iterable<? extends RexNode> nodes) {
+    final GroupKeyImpl groupKeyImpl = (GroupKeyImpl) groupKey;
+    final AggBuilder aggBuilder = new AggBuilder(groupKeyImpl.nodes);
+    for (RexNode node : nodes) {
+      aggBuilder.add(node);
+    }
+    return aggregate(groupKey, aggBuilder.aggCalls)
+        .project(
+            Iterables.concat(
+                fields(Util.range(projectKey ? groupKey.groupKeyCount() : 0)),
+                aggBuilder.postProjects));
+  }
+
   /** Finishes the implementation of {@link #aggregate} by creating an
    * {@link Aggregate} and pushing it onto the stack. */
   private RelBuilder aggregate_(ImmutableBitSet groupSet,
@@ -4972,4 +4995,49 @@ public class RelBuilder {
     Config withRemoveRedundantDistinct(boolean removeRedundantDistinct);
   }
 
+  /** Working state for {@link #aggregateRex}. */
+  private class AggBuilder {
+    final ImmutableList<RexNode> groupKeys;
+    final List<RexNode> postProjects = new ArrayList<>();
+    final List<AggCall> aggCalls = new ArrayList<>();
+
+    private AggBuilder(ImmutableList<RexNode> groupKeys) {
+      this.groupKeys = groupKeys;
+    }
+
+    /** Adds a node that may or may not contain an aggregate function. */
+    void add(RexNode node) {
+      postProjects.add(convert(node));
+    }
+
+    /** Adds a node that we know to contain an aggregate function, and returns
+     * an expression whose input row type is the output row type of the
+     * aggregate layer ({@link #groupKeys} and {@link #aggCalls}). */
+    private RexNode convert(RexNode node) {
+      final RexBuilder rexBuilder = cluster.getRexBuilder();
+      if (node instanceof RexCall) {
+        final RexCall call = (RexCall) node;
+        if (call.getOperator().isAggregator()) {
+          final AggCall aggCall =
+              aggregateCall((SqlAggFunction) call.op, call.operands);
+          final int i = groupKeys.size() + aggCalls.size();
+          aggCalls.add(aggCall);
+          return rexBuilder.makeInputRef(call.getType(), i);
+        } else {
+          final List<RexNode> operands = new ArrayList<>();
+          call.operands.forEach(operand ->
+              operands.add(convert(operand)));
+          return call.clone(call.type, operands);
+        }
+      } else if (node instanceof RexInputRef) {
+        final int j = groupKeys.indexOf(node);
+        if (j < 0) {
+          throw new IllegalArgumentException("not a group key: " + node);
+        }
+        return rexBuilder.makeInputRef(node.getType(), j);
+      } else {
+        return node;
+      }
+    }
+  }
 }
diff --git a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
index e71b816bc8..a746adfba1 100644
--- a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
@@ -121,6 +121,7 @@ import java.util.function.Function;
 import java.util.function.UnaryOperator;
 import java.util.stream.Collectors;
 
+import static org.apache.calcite.test.Matchers.hasFieldNames;
 import static org.apache.calcite.test.Matchers.hasHints;
 import static org.apache.calcite.test.Matchers.hasTree;
 
@@ -1158,7 +1159,7 @@ public class RelBuilderTest {
             .rename(ImmutableList.of("x", "y z"))
             .build();
     assertThat(root, hasTree(expected));
-    assertThat(root.getRowType().getFieldNames(), hasToString("[x, y z]"));
+    assertThat(root, hasFieldNames("[x, y z]"));
   }
 
   /** Tests conditional rename using {@link RelBuilder#let}. */
@@ -2166,7 +2167,7 @@ public class RelBuilderTest {
    * GROUP_ID()</a>. */
   @Test void testAggregateGroupingSetsGroupId() {
     final String plan = ""
-        + "LogicalProject(JOB=[$0], DEPTNO=[$1], $f2=[0:BIGINT])\n"
+        + "LogicalProject(JOB=[$0], DEPTNO=[$1], g=[0:BIGINT])\n"
         + "  LogicalAggregate(group=[{2, 7}], groups=[[{2, 7}, {2}, {7}]])\n"
         + "    LogicalTableScan(table=[[scott, EMP]])\n";
     assertThat(groupIdRel(createBuilder(), false), hasTree(plan));
@@ -2177,10 +2178,10 @@ public class RelBuilderTest {
     // If any group occurs more than once, we need a UNION ALL.
     final String plan2 = ""
         + "LogicalUnion(all=[true])\n"
-        + "  LogicalProject(JOB=[$0], DEPTNO=[$1], $f2=[0:BIGINT])\n"
+        + "  LogicalProject(JOB=[$0], DEPTNO=[$1], g=[0:BIGINT])\n"
         + "    LogicalAggregate(group=[{2, 7}], groups=[[{2, 7}, {2}, {7}]])\n"
         + "      LogicalTableScan(table=[[scott, EMP]])\n"
-        + "  LogicalProject(JOB=[$0], DEPTNO=[$1], $f2=[1:BIGINT])\n"
+        + "  LogicalProject(JOB=[$0], DEPTNO=[$1], g=[1:BIGINT])\n"
         + "    LogicalAggregate(group=[{2, 7}])\n"
         + "      LogicalTableScan(table=[[scott, EMP]])\n";
     assertThat(groupIdRel(createBuilder(), true), hasTree(plan2));
@@ -2200,7 +2201,7 @@ public class RelBuilderTest {
                     .addAll(extra ? ImmutableList.of(builder.fields(djList))
                         : ImmutableList.of())
                     .build()),
-            builder.aggregateCall(SqlStdOperatorTable.GROUP_ID))
+            builder.aggregateCall(SqlStdOperatorTable.GROUP_ID).as("g"))
         .build();
   }
 
@@ -3279,6 +3280,112 @@ public class RelBuilderTest {
     assertThat(root, hasTree(expected));
   }
 
+  /** Test case for
+   * <a href="https://issues.apache.org/jira/browse/CALCITE-5802">[CALCITE-5802]
+   * In RelBuilder, add method aggregateRex, to allow aggregating complex
+   * expressions such as "1 + SUM(x + 2)"</a>. */
+  @Test void testAggregateRex() {
+    // SELECT deptno,
+    //   deptno + 2 AS d2,
+    //   3 + SUM(4 + sal) AS s
+    // FROM emp
+    // GROUP BY deptno
+    Function<RelBuilder, RelNode> f = b ->
+        b.scan("EMP")
+            .aggregateRex(b.groupKey(b.field("DEPTNO")),
+                b.field("DEPTNO"),
+                b.alias(
+                    b.call(SqlStdOperatorTable.PLUS, b.field("DEPTNO"),
+                        b.literal(2)),
+                    "d2"),
+                b.alias(
+                    b.call(SqlStdOperatorTable.PLUS, b.literal(3),
+                        b.call(SqlStdOperatorTable.SUM,
+                            b.call(SqlStdOperatorTable.PLUS, b.literal(4),
+                                b.field("SAL")))),
+                    "s"))
+            .build();
+    final String expected = ""
+        + "LogicalProject(DEPTNO=[$0], d2=[+($0, 2)], s=[+(3, $1)])\n"
+        + "  LogicalAggregate(group=[{0}], agg#0=[SUM($1)])\n"
+        + "    LogicalProject(DEPTNO=[$7], $f8=[+(4, $5)])\n"
+        + "      LogicalTableScan(table=[[scott, EMP]])\n";
+    final String expectedRowType =
+        "RecordType(TINYINT DEPTNO, INTEGER d2, DECIMAL(19, 2) s) NOT NULL";
+    final RelNode r = f.apply(createBuilder());
+    assertThat(r, hasTree(expected));
+    assertThat(r.getRowType().getFullTypeString(), is(expectedRowType));
+  }
+
+  /** Tests {@link RelBuilder#aggregateRex} with an expression;
+   * it needs to be evaluated post aggregation. */
+  @Test void testAggregateRex2() {
+    // SELECT CURRENT_DATE AS d
+    // FROM emp
+    // GROUP BY ()
+    BiFunction<RelBuilder, Boolean, RelNode> f = (b, projectKey) ->
+        b.scan("EMP")
+            .aggregateRex(b.groupKey(), projectKey,
+                ImmutableList.of(
+                    b.alias(b.call(SqlStdOperatorTable.CURRENT_DATE), "d")))
+            .build();
+    final String expected = ""
+        + "LogicalProject(d=[CURRENT_DATE])\n"
+        + "  LogicalValues(tuples=[[{ true }]])\n";
+    final String expectedRowType = "RecordType(DATE NOT NULL d) NOT NULL";
+    final RelNode r = f.apply(createBuilder(), false);
+    assertThat(r, hasTree(expected));
+    assertThat(r.getRowType().getFullTypeString(), is(expectedRowType));
+
+    // As above, with projectKey = true
+    final RelNode r2 = f.apply(createBuilder(), true);
+    assertThat(r2, hasTree(expected));
+    assertThat(r2.getRowType().getFullTypeString(), is(expectedRowType));
+
+    // As above, disabling extra fields
+    final String expected3 = ""
+        + "LogicalProject(d=[CURRENT_DATE])\n"
+        + "  LogicalValues(tuples=[[{  }]])\n";
+    final RelNode r3 =
+        f.apply(createBuilder(c -> c.withPreventEmptyFieldList(false)),
+            false);
+    assertThat(r3, hasTree(expected3));
+    assertThat(r3.getRowType().getFullTypeString(), is(expectedRowType));
+  }
+
+  /** Tests {@link RelBuilder#aggregateRex} with a literal expression;
+   * it needs to be evaluated post aggregation. */
+  @Test void testAggregateRex3() {
+    // SELECT 2 AS two, false AS f
+    // FROM emp
+    // GROUP BY ()
+    BiFunction<RelBuilder, Boolean, RelNode> f = (b, projectKey) ->
+        b.scan("EMP")
+            .aggregateRex(b.groupKey(), projectKey,
+                ImmutableList.of(b.alias(b.literal(2), "two"),
+                    b.alias(b.literal(false), "f")))
+            .build();
+    final String expected =
+        "LogicalValues(tuples=[[{ 2, false }]])\n";
+    final String expectedRowType =
+        "RecordType(INTEGER NOT NULL two, BOOLEAN NOT NULL f) NOT NULL";
+    final RelNode r = f.apply(createBuilder(), false);
+    assertThat(r, hasTree(expected));
+    assertThat(r.getRowType().getFullTypeString(), is(expectedRowType));
+
+    // As above, with projectKey = true
+    final RelNode r2 = f.apply(createBuilder(), true);
+    assertThat(r2, hasTree(expected));
+    assertThat(r2.getRowType().getFullTypeString(), is(expectedRowType));
+
+    // As above, disabling extra fields
+    final RelNode r3 =
+        f.apply(createBuilder(c -> c.withPreventEmptyFieldList(false)),
+            false);
+    assertThat(r3, hasTree(expected));
+    assertThat(r3.getRowType().getFullTypeString(), is(expectedRowType));
+  }
+
   /** Tests that a projection retains field names after a join. */
   @Test void testProjectJoin() {
     final RelBuilder builder = RelBuilder.create(config().build());
@@ -3765,10 +3872,11 @@ public class RelBuilderTest {
             .build();
     final String expected =
         "LogicalValues(tuples=[[{ 1, true }, { 2, false }]])\n";
-    final String expectedRowType = "RecordType(INTEGER x, BOOLEAN y)";
-    assertThat(f.apply(createBuilder()), hasTree(expected));
-    assertThat(f.apply(createBuilder()).getRowType(),
-        hasToString(expectedRowType));
+    final String expectedRowType =
+        "RecordType(INTEGER NOT NULL x, BOOLEAN NOT NULL y) NOT NULL";
+    final RelNode r = f.apply(createBuilder());
+    assertThat(r, hasTree(expected));
+    assertThat(r.getRowType().getFullTypeString(), is(expectedRowType));
   }
 
   /** Tests that {@code Union(Project(Values), ... Project(Values))} is

Reply via email to