This is an automated email from the ASF dual-hosted git repository.

jhyde pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new 8771e3f94b [CALCITE-6555] RelBuilder.aggregateRex wrongly thinks aggregate functions of "GROUP BY ()" queries are NOT NULL
8771e3f94b is described below

commit 8771e3f94b2c4f132bfa51a0b93ddb7a11cd97bc
Author: Julian Hyde <[email protected]>
AuthorDate: Thu Aug 29 13:30:18 2024 -0700

    [CALCITE-6555] RelBuilder.aggregateRex wrongly thinks aggregate functions of "GROUP BY ()" queries are NOT NULL
    
    In RelBuilder, the aggregateRex method (added in CALCITE-5802)
    wrongly thinks that aggregate functions in a `GROUP BY ()`
    query are NOT NULL. Consider the query
    
      SELECT SUM(empno) AS s, COUNT(empno) AS c
      FROM emp
      GROUP BY ()
    
    `SUM(empno)` should be nullable, even though `empno` has type
    `SMALLINT NOT NULL`, because `GROUP BY ()` will return one row
    even if `emp` has no rows, and therefore `SUM` will be
    evaluated over the empty set. A RelBuilder test that attempts
    to build an equivalent query gets the following error stack:
    
      java.lang.AssertionError: type mismatch:
      ref:
      SMALLINT NOT NULL
      input:
      SMALLINT
    
    We add a test case for measure queries, because measures are
    the only code path that uses `aggregateRex` at present.
---
 .../java/org/apache/calcite/tools/RelBuilder.java  | 105 +++++++++++++++------
 .../org/apache/calcite/test/RelBuilderTest.java    |  23 +++++
 core/src/test/resources/sql/measure.iq             |  21 +++++
 3 files changed, 119 insertions(+), 30 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
index 69f7c43d58..e6c88d7808 100644
--- a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
+++ b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
@@ -150,6 +150,7 @@ import java.util.Objects;
 import java.util.Set;
 import java.util.SortedSet;
 import java.util.TreeSet;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.BiFunction;
 import java.util.function.Consumer;
 import java.util.function.Function;
@@ -2632,14 +2633,22 @@ public class RelBuilder {
       Iterable<? extends RexNode> nodes) {
     final GroupKeyImpl groupKeyImpl = (GroupKeyImpl) groupKey;
     final AggBuilder aggBuilder = new AggBuilder(groupKeyImpl.nodes);
-    for (RexNode node : nodes) {
-      aggBuilder.add(node);
+
+    // First pass. Call convert on each expression to ensure that aggCalls
+    // gets populated.
+    aggBuilder.registerExpressions(nodes);
+
+    // Create the Aggregate on the stack.
+    aggregate(groupKey, aggBuilder.aggCalls);
+
+    // Second pass. Call convert on each expression so that it references the
+    // actual aggCalls in the Aggregate that was just pushed onto the stack.
+    final List<RexNode> projects = new ArrayList<>();
+    if (projectKey) {
+      projects.addAll(fields(Util.range(groupKey.groupKeyCount())));
     }
-    return aggregate(groupKey, aggBuilder.aggCalls)
-        .project(
-            Iterables.concat(
-                fields(Util.range(projectKey ? groupKey.groupKeyCount() : 0)),
-                aggBuilder.postProjects));
+    aggBuilder.convertExpressions(projects::add, nodes);
+    return project(projects);
   }
 
   /** Finishes the implementation of {@link #aggregate} by creating an
@@ -5040,46 +5049,82 @@ public class RelBuilder {
   /** Working state for {@link #aggregateRex}. */
   private class AggBuilder {
     final ImmutableList<RexNode> groupKeys;
-    final List<RexNode> postProjects = new ArrayList<>();
     final List<AggCall> aggCalls = new ArrayList<>();
 
     private AggBuilder(ImmutableList<RexNode> groupKeys) {
       this.groupKeys = groupKeys;
     }
 
-    /** Adds a node that may or may not contain an aggregate function. */
-    void add(RexNode node) {
-      postProjects.add(convert(node));
-    }
-
     /** Adds a node that we know to contain an aggregate function, and returns
      * an expression whose input row type is the output row type of the
      * aggregate layer ({@link #groupKeys} and {@link #aggCalls}). */
-    private RexNode convert(RexNode node) {
-      final RexBuilder rexBuilder = cluster.getRexBuilder();
-      if (node instanceof RexCall) {
-        final RexCall call = (RexCall) node;
-        if (call.getOperator().isAggregator()) {
-          final AggCall aggCall =
-              aggregateCall((SqlAggFunction) call.op, call.operands);
-          final int i = groupKeys.size() + aggCalls.size();
-          aggCalls.add(aggCall);
-          return rexBuilder.makeInputRef(call.getType(), i);
+    private RexNode convert(RegisterAgg registrar, RexNode node,
+        @Nullable String name) {
+      switch (node.getKind()) {
+      case AS:
+        final ImmutableList<RexNode> asOperands = ((RexCall) node).operands;
+        final String name2;
+        if (name != null) {
+          name2 = name;
         } else {
-          final List<RexNode> operands = new ArrayList<>();
-          call.operands.forEach(operand ->
-              operands.add(convert(operand)));
-          return call.clone(call.type, operands);
+          final RexLiteral literal = (RexLiteral) asOperands.get(1);
+          name2 = requireNonNull(literal.getValueAs(String.class));
         }
-      } else if (node instanceof RexInputRef) {
+        final RexNode node2 = convert(registrar, asOperands.get(0), name2);
+        return alias(node2, name2);
+
+      case INPUT_REF:
         final int j = groupKeys.indexOf(node);
         if (j < 0) {
           throw new IllegalArgumentException("not a group key: " + node);
         }
-        return rexBuilder.makeInputRef(node.getType(), j);
-      } else {
+        return field(j);
+
+      default:
+        if (node instanceof RexCall) {
+          final RexCall call = (RexCall) node;
+          if (call.getOperator().isAggregator()) {
+            // return a reference to the i'th agg call
+            return registrar.registerAgg((SqlAggFunction) call.op,
+                call.operands, call.type, name);
+          } else {
+            return call.clone(call.type,
+                Util.transform(call.operands, operand ->
+                    convert(registrar, operand, null)));
+          }
+        }
         return node;
       }
     }
+
+    void registerExpressions(Iterable<? extends RexNode> nodes) {
+      for (RexNode node : nodes) {
+        convert(this::registerAgg, node, null);
+      }
+    }
+
+    RexInputRef registerAgg(SqlAggFunction op, List<RexNode> operands,
+        RelDataType type, @Nullable String name) {
+      final int i = groupKeys.size() + aggCalls.size();
+      aggCalls.add(aggregateCall(op, operands).as(name));
+      return getRexBuilder().makeInputRef(type, i);
+    }
+
+    void convertExpressions(Consumer<RexNode> projects,
+        Iterable<? extends RexNode> nodes) {
+      final AtomicInteger j = new AtomicInteger(groupKeys.size());
+      for (RexNode node : nodes) {
+        projects.accept(
+            convert((op, operands, type, name) -> field(j.getAndIncrement()),
+                node, null));
+      }
+    }
+  }
+
+  /** Callback to handle creation of an aggregate call in
+   * {@link AggBuilder#convert}. */
+  private interface RegisterAgg {
+    RexInputRef registerAgg(SqlAggFunction op, List<RexNode> operands,
+        RelDataType type, @Nullable String name);
   }
 }
diff --git a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
index c2818989ce..829d76f5b0 100644
--- a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
@@ -3398,6 +3398,29 @@ public class RelBuilderTest {
     assertThat(r3.getRowType().getFullTypeString(), is(expectedRowType));
   }
 
+  /** Tests {@link RelBuilder#aggregateRex} with an aggregate call that needs to
+   * become nullable because of "GROUP BY ()". */
+  @Test void testAggregateRex4() {
+    // SELECT SUM(empno) AS s, COUNT(sal) AS c
+    // FROM emp
+    // GROUP BY ()
+    Function<RelBuilder, RelNode> f = b ->
+        b.scan("EMP")
+            .aggregateRex(b.groupKey(),
+                b.alias(b.call(SqlStdOperatorTable.SUM, b.field("EMPNO")), "s"),
+                b.alias(b.call(SqlStdOperatorTable.COUNT, b.field("SAL")), "c"))
+            .build();
+    final String expected =
+        "LogicalAggregate(group=[{}], s=[SUM($0)], c=[COUNT($5)])\n"
+            + "  LogicalTableScan(table=[[scott, EMP]])\n";
+    // s is nullable because "GROUP BY ()" may have a group that contains 0 rows
+    final String expectedRowType =
+        "RecordType(SMALLINT s, BIGINT NOT NULL c) NOT NULL";
+    final RelNode r = f.apply(createBuilder());
+    assertThat(r, hasTree(expected));
+    assertThat(r.getRowType().getFullTypeString(), is(expectedRowType));
+  }
+
   /** Tests that a projection retains field names after a join. */
   @Test void testProjectJoin() {
     final RelBuilder builder = RelBuilder.create(config().build());
diff --git a/core/src/test/resources/sql/measure.iq b/core/src/test/resources/sql/measure.iq
index 8a59142424..e6aa11131e 100644
--- a/core/src/test/resources/sql/measure.iq
+++ b/core/src/test/resources/sql/measure.iq
@@ -84,6 +84,27 @@ group by job;
 
 !ok
 
+# Measure on primary key gives type error (casting away NOT NULL); cause was
+# [CALCITE-6555] RelBuilder.aggregateRex thinks aggregate functions of
+# "GROUP BY ()" queries are NOT NULL
+with empm as (
+  select *, min(empno) as measure avg_sal
+  from emp
+)
+select deptno, avg_sal as a
+from empm
+group by deptno;
++--------+------+
+| DEPTNO | A    |
++--------+------+
+|     10 | 7782 |
+|     20 | 7369 |
+|     30 | 7499 |
++--------+------+
+(3 rows)
+
+!ok
+
 # Equivalent using AGGREGATE
 select job, aggregate(avg_sal) as a
 from empm

Reply via email to