This is an automated email from the ASF dual-hosted git repository.
hyuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/master by this push:
new 8a1535f [CALCITE-4652] AggregateExpandDistinctAggregatesRule must
cast top aggregates to original type (Taras Ledkov)
8a1535f is described below
commit 8a1535f94aad1e0ce77797eb84d75b4a5b1692c1
Author: tledkov <[email protected]>
AuthorDate: Fri Jun 4 17:54:17 2021 +0300
[CALCITE-4652] AggregateExpandDistinctAggregatesRule must cast top
aggregates to original type (Taras Ledkov)
Close #2439
---
.../AggregateExpandDistinctAggregatesRule.java | 12 ++++-
.../org/apache/calcite/test/RelOptRulesTest.java | 44 ++++++++++++++++
.../org/apache/calcite/test/SqlToRelTestBase.java | 58 +++++++++++++---------
.../org/apache/calcite/test/RelOptRulesTest.xml | 21 ++++++++
4 files changed, 111 insertions(+), 24 deletions(-)
diff --git
a/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java
b/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java
index 6ef9dae..cec3e58 100644
---
a/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java
+++
b/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java
@@ -26,6 +26,7 @@ import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.logical.LogicalAggregate;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
@@ -366,12 +367,15 @@ public final class AggregateExpandDistinctAggregatesRule
final int arg = bottomGroups.size() + nonDistinctAggCallProcessedSoFar;
final List<Integer> newArgs = ImmutableList.of(arg);
if (aggCall.getAggregation().getKind() == SqlKind.COUNT) {
+ RelDataTypeFactory typeFactory =
aggregate.getCluster().getTypeFactory();
+
newCall =
AggregateCall.create(new SqlSumEmptyIsZeroAggFunction(), false,
aggCall.isApproximate(), aggCall.ignoreNulls(),
newArgs, -1, aggCall.distinctKeys, aggCall.collation,
originalGroupSet.cardinality(), relBuilder.peek(),
- aggCall.getType(), aggCall.getName());
+ typeFactory.getTypeSystem().deriveSumType(typeFactory,
aggCall.getType()),
+ aggCall.getName());
} else {
newCall =
AggregateCall.create(aggCall.getAggregation(), false,
@@ -400,6 +404,12 @@ public final class AggregateExpandDistinctAggregatesRule
relBuilder.push(
aggregate.copy(aggregate.getTraitSet(), relBuilder.build(),
ImmutableBitSet.of(topGroupSet), null, topAggregateCalls));
+
+ // Add projection node for case: SUM of COUNT(*):
+ // Type of the SUM may be larger than type of COUNT.
+ // CAST to original type must be added.
+ relBuilder.convert(aggregate.getRowType(), true);
+
return relBuilder;
}
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index d2b53c1..809289b 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -89,6 +89,7 @@ import org.apache.calcite.rel.rules.UnionMergeRule;
import org.apache.calcite.rel.rules.ValuesReduceRule;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.rel.type.RelDataTypeSystemImpl;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
@@ -107,6 +108,7 @@ import org.apache.calcite.sql.fun.SqlLibrary;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
+import org.apache.calcite.sql.type.SqlTypeFactoryImpl;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.validate.SqlConformanceEnum;
import org.apache.calcite.sql.validate.SqlMonotonicity;
@@ -136,6 +138,7 @@ import java.util.List;
import java.util.Locale;
import java.util.function.Function;
import java.util.function.Predicate;
+import java.util.function.Supplier;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
@@ -6769,4 +6772,45 @@ class RelOptRulesTest extends RelOptTestBase {
relFn(relFn).with(hepPlanner).checkUnchanged();
}
}
+
+  /**
+   * Test case for <a href="https://issues.apache.org/jira/browse/CALCITE-4652">[CALCITE-4652]
+   * AggregateExpandDistinctAggregatesRule must cast top aggregates to original type</a>.
+   *
+   * <p>Checks AggregateExpandDistinctAggregatesRule when a custom type system
+   * widens the return type of the SUM aggregate (INTEGER/BIGINT to DECIMAL).
+   */
+  @Test void testDistinctCountWithExpandSumType() {
+    // Define new type system to expand SUM return type.
+    RelDataTypeSystemImpl typeSystem = new RelDataTypeSystemImpl() {
+      @Override public RelDataType deriveSumType(RelDataTypeFactory typeFactory,
+          RelDataType argumentType) {
+        switch (argumentType.getSqlTypeName()) {
+        case INTEGER:
+        case BIGINT:
+          return typeFactory.createSqlType(SqlTypeName.DECIMAL);
+
+        default:
+          return super.deriveSumType(typeFactory, argumentType);
+        }
+      }
+    };
+    // getTypeFactory() calls the supplier on every use; hand out one shared instance.
+    final RelDataTypeFactory factory = new SqlTypeFactoryImpl(typeSystem);
+    final Supplier<RelDataTypeFactory> typeFactorySupplier = () -> factory;
+
+    // Expected plan:
+    // LogicalProject(EXPR$0=[CAST($0):BIGINT NOT NULL], EXPR$1=[$1])
+    //   LogicalAggregate(group=[{}], EXPR$0=[$SUM0($1)], EXPR$1=[COUNT($0)])
+    //     LogicalAggregate(group=[{0}], EXPR$0=[COUNT()])
+    //       LogicalProject(COMM=[$6])
+    //         LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+    //
+    // The top 'LogicalProject' is required: 'count(comm)' (EXPR$0) is BIGINT, but the
+    // rule rewrites it as $SUM0 of partial counts, and SUM of BIGINT is DECIMAL here.
+    sql("SELECT count(comm), COUNT(DISTINCT comm) FROM emp")
+        .withTester(t -> t.withTypeFactorySupplier(typeFactorySupplier))
+        .withRule(CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN)
+        .check();
+  }
}
diff --git a/core/src/test/java/org/apache/calcite/test/SqlToRelTestBase.java
b/core/src/test/java/org/apache/calcite/test/SqlToRelTestBase.java
index 93921f1..85d581b 100644
--- a/core/src/test/java/org/apache/calcite/test/SqlToRelTestBase.java
+++ b/core/src/test/java/org/apache/calcite/test/SqlToRelTestBase.java
@@ -74,6 +74,7 @@ import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.TestUtil;
+import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
@@ -81,6 +82,7 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;
+import java.util.function.Supplier;
import java.util.function.UnaryOperator;
import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -100,6 +102,8 @@ public abstract class SqlToRelTestBase {
//~ Static fields/initializers ---------------------------------------------
protected static final String NL = System.getProperty("line.separator");
+ protected static final Supplier<RelDataTypeFactory>
DEFAULT_TYPE_FACTORY_SUPPLIER =
+ Suppliers.memoize(() -> new
SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT));
//~ Instance fields --------------------------------------------------------
@@ -111,7 +115,7 @@ public abstract class SqlToRelTestBase {
final TesterImpl tester =
new TesterImpl(getDiffRepos(), false, false, false, true, null, null,
MockRelOptPlanner::new, UnaryOperator.identity(),
- SqlConformanceEnum.DEFAULT, UnaryOperator.identity());
+ SqlConformanceEnum.DEFAULT, UnaryOperator.identity(),
DEFAULT_TYPE_FACTORY_SUPPLIER);
return tester.withConfig(c ->
c.withTrimUnusedFields(true)
.withExpand(true)
@@ -287,6 +291,9 @@ public abstract class SqlToRelTestBase {
/** Returns a tester that uses a given context. */
Tester withContext(UnaryOperator<Context> transform);
+ /** Returns a tester that calls the given supplier each time it needs a type factory. */
+ Tester withTypeFactorySupplier(Supplier<RelDataTypeFactory>
typeFactorySupplier);
+
/** Trims a RelNode. */
RelNode trimRelNode(RelNode relNode);
@@ -564,7 +571,7 @@ public abstract class SqlToRelTestBase {
private final SqlConformance conformance;
private final SqlTestFactory.MockCatalogReaderFactory catalogReaderFactory;
private final Function<RelOptCluster, RelOptCluster> clusterFactory;
- private RelDataTypeFactory typeFactory;
+ private final Supplier<RelDataTypeFactory> typeFactorySupplier;
private final UnaryOperator<SqlToRelConverter.Config> configTransform;
private final UnaryOperator<Context> contextTransform;
@@ -572,7 +579,7 @@ public abstract class SqlToRelTestBase {
protected TesterImpl(DiffRepository diffRepos) {
this(diffRepos, true, true, false, true, null, null,
MockRelOptPlanner::new, UnaryOperator.identity(),
- SqlConformanceEnum.DEFAULT, c -> Contexts.empty());
+ SqlConformanceEnum.DEFAULT, c -> Contexts.empty(),
DEFAULT_TYPE_FACTORY_SUPPLIER);
}
/**
@@ -591,7 +598,8 @@ public abstract class SqlToRelTestBase {
Function<RelOptCluster, RelOptCluster> clusterFactory,
Function<Context, RelOptPlanner> plannerFactory,
UnaryOperator<SqlToRelConverter.Config> configTransform,
- SqlConformance conformance, UnaryOperator<Context> contextTransform) {
+ SqlConformance conformance, UnaryOperator<Context> contextTransform,
+ Supplier<RelDataTypeFactory> typeFactorySupplier) {
this.diffRepos = diffRepos;
this.enableDecorrelate = enableDecorrelate;
this.enableTrim = enableTrim;
@@ -603,6 +611,7 @@ public abstract class SqlToRelTestBase {
this.plannerFactory = Objects.requireNonNull(plannerFactory,
"plannerFactory");
this.conformance = Objects.requireNonNull(conformance, "conformance");
this.contextTransform = Objects.requireNonNull(contextTransform,
"contextTransform");
+ this.typeFactorySupplier = Objects.requireNonNull(typeFactorySupplier,
"typeFactorySupplier");
}
public RelRoot convertSqlToRel(String sql) {
@@ -667,7 +676,7 @@ public abstract class SqlToRelTestBase {
return createSqlToRelConverter(
validator,
catalogReader,
- typeFactory,
+ getTypeFactory(),
config);
}
@@ -689,14 +698,7 @@ public abstract class SqlToRelTestBase {
}
protected final RelDataTypeFactory getTypeFactory() {
- if (typeFactory == null) {
- typeFactory = createTypeFactory();
- }
- return typeFactory;
- }
-
- protected RelDataTypeFactory createTypeFactory() {
- return new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT);
+ return typeFactorySupplier.get();
}
protected final RelOptPlanner getPlanner() {
@@ -899,7 +901,7 @@ public abstract class SqlToRelTestBase {
: new TesterImpl(diffRepos, enableDecorrelate, enableTrim,
enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory,
clusterFactory, plannerFactory, configTransform, conformance,
- contextTransform);
+ contextTransform, typeFactorySupplier);
}
public TesterImpl withLateDecorrelation(boolean enableLateDecorrelate) {
@@ -908,7 +910,7 @@ public abstract class SqlToRelTestBase {
: new TesterImpl(diffRepos, enableDecorrelate, enableTrim,
enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory,
clusterFactory, plannerFactory, configTransform, conformance,
- contextTransform);
+ contextTransform, typeFactorySupplier);
}
public Tester withConfig(UnaryOperator<SqlToRelConverter.Config>
transform) {
@@ -917,7 +919,7 @@ public abstract class SqlToRelTestBase {
return new TesterImpl(diffRepos, enableDecorrelate, enableTrim,
enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory,
clusterFactory, plannerFactory, configTransform, conformance,
- contextTransform);
+ contextTransform, typeFactorySupplier);
}
public TesterImpl withTrim(boolean enableTrim) {
@@ -926,21 +928,21 @@ public abstract class SqlToRelTestBase {
: new TesterImpl(diffRepos, enableDecorrelate, enableTrim,
enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory,
clusterFactory, plannerFactory, configTransform, conformance,
- contextTransform);
+ contextTransform, typeFactorySupplier);
}
public TesterImpl withConformance(SqlConformance conformance) {
return new TesterImpl(diffRepos, enableDecorrelate, enableTrim,
enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory,
clusterFactory, plannerFactory, configTransform, conformance,
- contextTransform);
+ contextTransform, typeFactorySupplier);
}
public Tester enableTypeCoercion(boolean enableTypeCoercion) {
return new TesterImpl(diffRepos, enableDecorrelate, enableTrim,
enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory,
clusterFactory, plannerFactory, configTransform, conformance,
- contextTransform);
+ contextTransform, typeFactorySupplier);
}
public Tester withCatalogReaderFactory(
@@ -948,7 +950,7 @@ public abstract class SqlToRelTestBase {
return new TesterImpl(diffRepos, enableDecorrelate, enableTrim,
enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory,
clusterFactory, plannerFactory, configTransform, conformance,
- contextTransform);
+ contextTransform, typeFactorySupplier);
}
public Tester withClusterFactory(
@@ -956,7 +958,7 @@ public abstract class SqlToRelTestBase {
return new TesterImpl(diffRepos, enableDecorrelate, enableTrim,
enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory,
clusterFactory, plannerFactory, configTransform, conformance,
- contextTransform);
+ contextTransform, typeFactorySupplier);
}
public Tester withPlannerFactory(
@@ -966,14 +968,24 @@ public abstract class SqlToRelTestBase {
: new TesterImpl(diffRepos, enableDecorrelate, enableTrim,
enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory,
clusterFactory, plannerFactory, configTransform, conformance,
- contextTransform);
+ contextTransform, typeFactorySupplier);
+ }
+
+  public Tester withTypeFactorySupplier(
+      Supplier<RelDataTypeFactory> typeFactorySupplier) {
+    if (typeFactorySupplier == this.typeFactorySupplier) {
+      return this; // same supplier: no copy needed
+    }
+    return new TesterImpl(diffRepos, enableDecorrelate, enableTrim, enableLateDecorrelate,
+        enableTypeCoercion, catalogReaderFactory, clusterFactory, plannerFactory,
+        configTransform, conformance, contextTransform, typeFactorySupplier);
   }
public TesterImpl withContext(UnaryOperator<Context> context) {
return new TesterImpl(diffRepos, enableDecorrelate, enableTrim,
enableLateDecorrelate, enableTypeCoercion, catalogReaderFactory,
clusterFactory, plannerFactory, configTransform, conformance,
- context);
+ context, typeFactorySupplier);
}
public boolean isLateDecorrelate() {
diff --git
a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index f811fba..3ed9020 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -2369,6 +2369,27 @@ LogicalProject(DEPTNO=[$0], EXPR$1=[$3], EXPR$2=[$5],
EXPR$3=[$7], EXPR$4=[$1])
]]>
</Resource>
</TestCase>
+ <TestCase name="testDistinctCountWithExpandSumType">
+ <Resource name="sql">
+ <![CDATA[SELECT count(comm), COUNT(DISTINCT comm) FROM emp]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[COUNT()], EXPR$1=[COUNT(DISTINCT $0)])
+ LogicalProject(COMM=[$6])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalProject(EXPR$0=[CAST($0):BIGINT NOT NULL], EXPR$1=[$1])
+ LogicalAggregate(group=[{}], EXPR$0=[$SUM0($1)], EXPR$1=[COUNT($0)])
+ LogicalAggregate(group=[{0}], EXPR$0=[COUNT()])
+ LogicalProject(COMM=[$6])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ </TestCase>
<TestCase name="testDistinctCountWithoutGroupBy">
<Resource name="sql">
<![CDATA[select max(deptno), count(distinct ename)