This is an automated email from the ASF dual-hosted git repository.
jhyde pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/main by this push:
new 8771e3f94b [CALCITE-6555] RelBuilder.aggregateRex wrongly thinks
aggregate functions of "GROUP BY ()" queries are NOT NULL
8771e3f94b is described below
commit 8771e3f94b2c4f132bfa51a0b93ddb7a11cd97bc
Author: Julian Hyde <[email protected]>
AuthorDate: Thu Aug 29 13:30:18 2024 -0700
[CALCITE-6555] RelBuilder.aggregateRex wrongly thinks aggregate functions
of "GROUP BY ()" queries are NOT NULL
In RelBuilder, the aggregateRex method (added in CALCITE-5802)
wrongly thinks that aggregate functions in a `GROUP BY ()`
query are NOT NULL. Consider the query
SELECT SUM(empno) AS s, COUNT(empno) AS c
FROM emp
GROUP BY ()
`SUM(empno)` should be nullable, even though `empno` has type
`SMALLINT NOT NULL`, because `GROUP BY ()` will return one row
even if `emp` has no rows, and therefore `SUM` will be
evaluated over the empty set. A RelBuilder test that attempts
to build an equivalent query gets the following error stack:
java.lang.AssertionError: type mismatch:
ref:
SMALLINT NOT NULL
input:
SMALLINT
We add a test case for measure queries, because measures are
the only code path that uses `aggregateRex` at present.
---
.../java/org/apache/calcite/tools/RelBuilder.java | 105 +++++++++++++++------
.../org/apache/calcite/test/RelBuilderTest.java | 23 +++++
core/src/test/resources/sql/measure.iq | 21 +++++
3 files changed, 119 insertions(+), 30 deletions(-)
diff --git a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
index 69f7c43d58..e6c88d7808 100644
--- a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
+++ b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
@@ -150,6 +150,7 @@ import java.util.Objects;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
+import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
@@ -2632,14 +2633,22 @@ public class RelBuilder {
Iterable<? extends RexNode> nodes) {
final GroupKeyImpl groupKeyImpl = (GroupKeyImpl) groupKey;
final AggBuilder aggBuilder = new AggBuilder(groupKeyImpl.nodes);
- for (RexNode node : nodes) {
- aggBuilder.add(node);
+
+ // First pass. Call convert on each expression to ensure that aggCalls
+ // gets populated.
+ aggBuilder.registerExpressions(nodes);
+
+ // Create the Aggregate on the stack.
+ aggregate(groupKey, aggBuilder.aggCalls);
+
+ // Second pass. Call convert on each expression so that it references the
+ // actual aggCalls in the Aggregate that was just pushed onto the stack.
+ final List<RexNode> projects = new ArrayList<>();
+ if (projectKey) {
+ projects.addAll(fields(Util.range(groupKey.groupKeyCount())));
}
- return aggregate(groupKey, aggBuilder.aggCalls)
- .project(
- Iterables.concat(
- fields(Util.range(projectKey ? groupKey.groupKeyCount() : 0)),
- aggBuilder.postProjects));
+ aggBuilder.convertExpressions(projects::add, nodes);
+ return project(projects);
}
/** Finishes the implementation of {@link #aggregate} by creating an
@@ -5040,46 +5049,82 @@ public class RelBuilder {
/** Working state for {@link #aggregateRex}. */
private class AggBuilder {
final ImmutableList<RexNode> groupKeys;
- final List<RexNode> postProjects = new ArrayList<>();
final List<AggCall> aggCalls = new ArrayList<>();
private AggBuilder(ImmutableList<RexNode> groupKeys) {
this.groupKeys = groupKeys;
}
- /** Adds a node that may or may not contain an aggregate function. */
- void add(RexNode node) {
- postProjects.add(convert(node));
- }
-
/** Adds a node that we know to contain an aggregate function, and returns
* an expression whose input row type is the output row type of the
* aggregate layer ({@link #groupKeys} and {@link #aggCalls}). */
- private RexNode convert(RexNode node) {
- final RexBuilder rexBuilder = cluster.getRexBuilder();
- if (node instanceof RexCall) {
- final RexCall call = (RexCall) node;
- if (call.getOperator().isAggregator()) {
- final AggCall aggCall =
- aggregateCall((SqlAggFunction) call.op, call.operands);
- final int i = groupKeys.size() + aggCalls.size();
- aggCalls.add(aggCall);
- return rexBuilder.makeInputRef(call.getType(), i);
+ private RexNode convert(RegisterAgg registrar, RexNode node,
+ @Nullable String name) {
+ switch (node.getKind()) {
+ case AS:
+ final ImmutableList<RexNode> asOperands = ((RexCall) node).operands;
+ final String name2;
+ if (name != null) {
+ name2 = name;
} else {
- final List<RexNode> operands = new ArrayList<>();
- call.operands.forEach(operand ->
- operands.add(convert(operand)));
- return call.clone(call.type, operands);
+ final RexLiteral literal = (RexLiteral) asOperands.get(1);
+ name2 = requireNonNull(literal.getValueAs(String.class));
}
- } else if (node instanceof RexInputRef) {
+ final RexNode node2 = convert(registrar, asOperands.get(0), name2);
+ return alias(node2, name2);
+
+ case INPUT_REF:
final int j = groupKeys.indexOf(node);
if (j < 0) {
throw new IllegalArgumentException("not a group key: " + node);
}
- return rexBuilder.makeInputRef(node.getType(), j);
- } else {
+ return field(j);
+
+ default:
+ if (node instanceof RexCall) {
+ final RexCall call = (RexCall) node;
+ if (call.getOperator().isAggregator()) {
+ // return a reference to the i'th agg call
+ return registrar.registerAgg((SqlAggFunction) call.op,
+ call.operands, call.type, name);
+ } else {
+ return call.clone(call.type,
+ Util.transform(call.operands, operand ->
+ convert(registrar, operand, null)));
+ }
+ }
return node;
}
}
+
+ void registerExpressions(Iterable<? extends RexNode> nodes) {
+ for (RexNode node : nodes) {
+ convert(this::registerAgg, node, null);
+ }
+ }
+
+ RexInputRef registerAgg(SqlAggFunction op, List<RexNode> operands,
+ RelDataType type, @Nullable String name) {
+ final int i = groupKeys.size() + aggCalls.size();
+ aggCalls.add(aggregateCall(op, operands).as(name));
+ return getRexBuilder().makeInputRef(type, i);
+ }
+
+ void convertExpressions(Consumer<RexNode> projects,
+ Iterable<? extends RexNode> nodes) {
+ final AtomicInteger j = new AtomicInteger(groupKeys.size());
+ for (RexNode node : nodes) {
+ projects.accept(
+ convert((op, operands, type, name) -> field(j.getAndIncrement()),
+ node, null));
+ }
+ }
+ }
+
+ /** Callback to handle creation of an aggregate call in
+ * {@link AggBuilder#convert}. */
+ private interface RegisterAgg {
+ RexInputRef registerAgg(SqlAggFunction op, List<RexNode> operands,
+ RelDataType type, @Nullable String name);
}
}
diff --git a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
index c2818989ce..829d76f5b0 100644
--- a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
@@ -3398,6 +3398,29 @@ public class RelBuilderTest {
assertThat(r3.getRowType().getFullTypeString(), is(expectedRowType));
}
+ /** Tests {@link RelBuilder#aggregateRex} with an aggregate call that needs to
+ * become nullable because of "GROUP BY ()". */
+ @Test void testAggregateRex4() {
+ // SELECT SUM(empno) AS s, COUNT(sal) AS c
+ // FROM emp
+ // GROUP BY ()
+ Function<RelBuilder, RelNode> f = b ->
+ b.scan("EMP")
+ .aggregateRex(b.groupKey(),
+ b.alias(b.call(SqlStdOperatorTable.SUM, b.field("EMPNO")), "s"),
+ b.alias(b.call(SqlStdOperatorTable.COUNT, b.field("SAL")), "c"))
+ .build();
+ final String expected =
+ "LogicalAggregate(group=[{}], s=[SUM($0)], c=[COUNT($5)])\n"
+ + " LogicalTableScan(table=[[scott, EMP]])\n";
+ // s is nullable because "GROUP BY ()" may have a group that contains 0 rows
+ final String expectedRowType =
+ "RecordType(SMALLINT s, BIGINT NOT NULL c) NOT NULL";
+ final RelNode r = f.apply(createBuilder());
+ assertThat(r, hasTree(expected));
+ assertThat(r.getRowType().getFullTypeString(), is(expectedRowType));
+ }
+
/** Tests that a projection retains field names after a join. */
@Test void testProjectJoin() {
final RelBuilder builder = RelBuilder.create(config().build());
diff --git a/core/src/test/resources/sql/measure.iq b/core/src/test/resources/sql/measure.iq
index 8a59142424..e6aa11131e 100644
--- a/core/src/test/resources/sql/measure.iq
+++ b/core/src/test/resources/sql/measure.iq
@@ -84,6 +84,27 @@ group by job;
!ok
+# Measure on primary key gives type error (casting away NOT NULL); cause was
+# [CALCITE-6555] RelBuilder.aggregateRex thinks aggregate functions of
+# "GROUP BY ()" queries are NOT NULL
+with empm as (
+ select *, min(empno) as measure avg_sal
+ from emp
+)
+select deptno, avg_sal as a
+from empm
+group by deptno;
++--------+------+
+| DEPTNO | A |
++--------+------+
+| 10 | 7782 |
+| 20 | 7369 |
+| 30 | 7499 |
++--------+------+
+(3 rows)
+
+!ok
+
# Equivalent using AGGREGATE
select job, aggregate(avg_sal) as a
from empm