This is an automated email from the ASF dual-hosted git repository.
sereda pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/master by this push:
new a75a689 [CALCITE-2820] Avoid reducing certain aggregate functions when it is not necessary (Siddharth Teotia)
a75a689 is described below
commit a75a689eff2f1333adc8fb800bdfa077e94da562
Author: siddharth <[email protected]>
AuthorDate: Wed Feb 6 12:06:13 2019 -0800
[CALCITE-2820] Avoid reducing certain aggregate functions when it is not necessary (Siddharth Teotia)
---
.../rel/rules/AggregateReduceFunctionsRule.java | 62 +++++++++--
.../org/apache/calcite/test/RelOptRulesTest.java | 81 ++++++++++++++
.../org/apache/calcite/test/RelOptRulesTest.xml | 123 +++++++++++++++++++++
3 files changed, 254 insertions(+), 12 deletions(-)
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
index e62c786..b2f416f 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
@@ -46,9 +46,11 @@ import com.google.common.collect.ImmutableList;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collections;
+import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
/**
* Planner rule that reduces aggregate functions in
@@ -97,16 +99,59 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
new AggregateReduceFunctionsRule(operand(LogicalAggregate.class, any()),
RelFactories.LOGICAL_BUILDER);
+ /** Kinds of aggregate function that this rule instance reduces. */
+ private final EnumSet<SqlKind> functionsToReduce;
+
//~ Constructors -----------------------------------------------------------
- /** Creates an AggregateReduceFunctionsRule. */
+ /**
+ * Creates an AggregateReduceFunctionsRule that reduces every function
+ * kind this rule supports.
+ *
+ * @param operand operand that determines whether the rule applies
+ * @param relBuilderFactory builder for relational expressions
+ */
public AggregateReduceFunctionsRule(RelOptRuleOperand operand,
RelBuilderFactory relBuilderFactory) {
super(operand, relBuilderFactory, null);
+ this.functionsToReduce = EnumSet.noneOf(SqlKind.class);
+ addDefaultSetOfFunctionsToReduce();
+ }
+
+ /**
+ * Creates an AggregateReduceFunctionsRule that reduces only the aggregate
+ * function kinds supplied by the caller.
+ *
+ * @param aggregateClass class of aggregate that this rule matches
+ * @param relBuilderFactory builder for relational expressions
+ * @param functionsToReduce kinds of aggregate function to reduce; each must
+ * be in {@link SqlKind#AVG_AGG_FUNCTIONS} or
+ * {@link SqlKind#COVAR_AVG_AGG_FUNCTIONS}, or be {@link SqlKind#SUM}
+ * @throws NullPointerException if {@code functionsToReduce} is null
+ * @throws IllegalArgumentException if {@code functionsToReduce} contains a
+ * kind that this rule cannot reduce
+ */
+ public AggregateReduceFunctionsRule(Class<? extends Aggregate> aggregateClass,
+ RelBuilderFactory relBuilderFactory, EnumSet<SqlKind> functionsToReduce) {
+ super(operand(aggregateClass, any()), relBuilderFactory, null);
+ Objects.requireNonNull(functionsToReduce,
+ "Expecting a valid handle for AggregateFunctionsToReduce");
+ // Copy into a private set so later changes to the caller's set cannot
+ // alter this rule; reject kinds that isReducible's callers cannot handle.
+ this.functionsToReduce = EnumSet.noneOf(SqlKind.class);
+ for (SqlKind function : functionsToReduce) {
+ if (SqlKind.AVG_AGG_FUNCTIONS.contains(function)
+ || SqlKind.COVAR_AVG_AGG_FUNCTIONS.contains(function)
+ || function == SqlKind.SUM) {
+ this.functionsToReduce.add(function);
+ } else {
+ throw new IllegalArgumentException(
+ "AggregateReduceFunctionsRule doesn't support function: "
+ + function.sql);
+ }
+ }
}
//~ Methods ----------------------------------------------------------------
+ /**
+ * Enables every function kind this rule supports: SUM plus the members of
+ * {@link SqlKind#COVAR_AVG_AGG_FUNCTIONS} and
+ * {@link SqlKind#AVG_AGG_FUNCTIONS}.
+ */
+ private void addDefaultSetOfFunctionsToReduce() {
+ functionsToReduce.add(SqlKind.SUM);
+ functionsToReduce.addAll(SqlKind.COVAR_AVG_AGG_FUNCTIONS);
+ functionsToReduce.addAll(SqlKind.AVG_AGG_FUNCTIONS);
+ }
+
@Override public boolean matches(RelOptRuleCall call) {
if (!super.matches(call)) {
return false;
@@ -138,20 +183,13 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
* Returns whether the aggregate call is a reducible function
*/
private boolean isReducible(final SqlKind kind) {
- if (SqlKind.AVG_AGG_FUNCTIONS.contains(kind)
- || SqlKind.COVAR_AVG_AGG_FUNCTIONS.contains(kind)) {
- return true;
- }
- switch (kind) {
- case SUM:
- return true;
- }
- return false;
+ // Only the kinds enabled when this rule was constructed are reduced.
+ return this.functionsToReduce.contains(kind);
}
/**
- * Reduces all calls to AVG, STDDEV_POP, STDDEV_SAMP, VAR_POP, VAR_SAMP in
- * the aggregates list to.
+ * Reduces calls to functions AVG, SUM, STDDEV_POP, STDDEV_SAMP, VAR_POP,
+ * VAR_SAMP, COVAR_POP, COVAR_SAMP, REGR_SXX, REGR_SYY if the function is
+ * present in {@link AggregateReduceFunctionsRule#functionsToReduce}
*
* <p>It handles newly generated common subexpressions since this was done
* at the sql2rel stage.
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index b4ce16e..d033818 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -44,6 +44,7 @@ import org.apache.calcite.rel.core.Minus;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.core.Union;
+import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalCorrelate;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.logical.LogicalTableModify;
@@ -137,6 +138,7 @@ import org.junit.Test;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.EnumSet;
import java.util.List;
import java.util.Properties;
import java.util.function.Predicate;
@@ -4347,6 +4349,85 @@ public class RelOptRulesTest extends RelOptTestBase {
final String planAfter = NL + RelOptUtil.toString(relAfter);
diffRepos.assertEquals("planAfter", "${planAfter}", planAfter);
}
+
+ @Test public void testReduceAverageWithNoReduceSum() {
+ // Only AVG is enabled, so the SUM produced by reducing AVG stays a SUM.
+ final AggregateReduceFunctionsRule rule =
+ new AggregateReduceFunctionsRule(LogicalAggregate.class,
+ RelFactories.LOGICAL_BUILDER, EnumSet.of(SqlKind.AVG));
+ final String sql = "select name, max(name), avg(deptno), min(name)"
+ + " from sales.dept group by name";
+ checkPlanning(rule, sql);
+ }
+
+ @Test public void testNoReduceAverage() {
+ // An empty set disables every reduction, so AVG must be left as is.
+ final AggregateReduceFunctionsRule rule =
+ new AggregateReduceFunctionsRule(LogicalAggregate.class,
+ RelFactories.LOGICAL_BUILDER, EnumSet.noneOf(SqlKind.class));
+ final HepProgram program =
+ new HepProgramBuilder().addRuleInstance(rule).build();
+ final String sql = "select name, max(name), avg(deptno), min(name)"
+ + " from sales.dept group by name";
+ sql(sql).with(program).checkUnchanged();
+ }
+
+ @Test public void testNoReduceSum() {
+ // An empty set disables every reduction, so SUM must be left as is.
+ final AggregateReduceFunctionsRule rule =
+ new AggregateReduceFunctionsRule(LogicalAggregate.class,
+ RelFactories.LOGICAL_BUILDER, EnumSet.noneOf(SqlKind.class));
+ final HepProgram program =
+ new HepProgramBuilder().addRuleInstance(rule).build();
+ final String sql = "select name, sum(deptno)"
+ + " from sales.dept group by name";
+ sql(sql).with(program).checkUnchanged();
+ }
+
+ @Test public void testReduceAverageAndVarWithNoReduceStddev() {
+ // Reduce only AVG and VAR_POP. STDDEV_POP is kept, and because SUM is
+ // not in the set, the SUM calls introduced by those reductions are kept
+ // as SUM rather than being reduced further.
+ final EnumSet<SqlKind> functionsToReduce =
+ EnumSet.of(SqlKind.AVG, SqlKind.VAR_POP);
+ final HepProgram program = new HepProgramBuilder()
+ .addRuleInstance(
+ new AggregateReduceFunctionsRule(LogicalAggregate.class,
+ RelFactories.LOGICAL_BUILDER, functionsToReduce))
+ .build();
+ final String sql = "select name, stddev_pop(deptno), avg(deptno),"
+ + " var_pop(deptno)\n"
+ + "from sales.dept group by name";
+ sql(sql).with(program).check();
+ }
+
+ @Test public void testReduceAverageAndSumWithNoReduceStddevAndVar() {
+ // Reduce only AVG and SUM; STDDEV_POP and VAR_POP are kept. Because SUM
+ // is in the set, the SUM introduced by reducing AVG is reduced as well
+ // ($SUM0 in the expected plan).
+ final EnumSet<SqlKind> functionsToReduce =
+ EnumSet.of(SqlKind.AVG, SqlKind.SUM);
+ final HepProgram program = new HepProgramBuilder()
+ .addRuleInstance(
+ new AggregateReduceFunctionsRule(LogicalAggregate.class,
+ RelFactories.LOGICAL_BUILDER, functionsToReduce))
+ .build();
+ final String sql = "select name, stddev_pop(deptno), avg(deptno),"
+ + " var_pop(deptno)\n"
+ + "from sales.dept group by name";
+ sql(sql).with(program).check();
+ }
+
+ @Test public void testReduceAllAggregateFunctions() {
+ // Reduce every function the query uses, plus SUM so that the sums
+ // introduced by those reductions are reduced as well.
+ final EnumSet<SqlKind> functionsToReduce = EnumSet.of(SqlKind.AVG,
+ SqlKind.SUM, SqlKind.STDDEV_POP, SqlKind.STDDEV_SAMP, SqlKind.VAR_POP,
+ SqlKind.VAR_SAMP);
+ final HepProgram program = new HepProgramBuilder()
+ .addRuleInstance(
+ new AggregateReduceFunctionsRule(LogicalAggregate.class,
+ RelFactories.LOGICAL_BUILDER, functionsToReduce))
+ .build();
+ final String sql = "select name, stddev_pop(deptno), avg(deptno),"
+ + " stddev_samp(deptno), var_pop(deptno), var_samp(deptno)\n"
+ + "from sales.dept group by name";
+ sql(sql).with(program).check();
+ }
}
// End RelOptRulesTest.java
diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index 45af69c..adfd110 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -8874,4 +8874,127 @@ LogicalSortExchange(distribution=[hash[1]], collation=[[1]])
]]>
</Resource>
</TestCase>
+ <TestCase name="testReduceAverageWithNoReduceSum">
+ <Resource name="sql">
+ <![CDATA[select name, max(name), avg(deptno), min(name) from sales.dept group by name]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[MAX($0)], EXPR$2=[AVG($1)], EXPR$3=[MIN($0)])
+ LogicalProject(NAME=[$1], DEPTNO=[$0])
+ LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalProject(NAME=[$0], EXPR$1=[$1], EXPR$2=[CAST(/($2, $3)):INTEGER NOT NULL], EXPR$3=[$4])
+ LogicalAggregate(group=[{0}], EXPR$1=[MAX($0)], agg#1=[SUM($1)], agg#2=[COUNT()], EXPR$3=[MIN($0)])
+ LogicalProject(NAME=[$1], DEPTNO=[$0])
+ LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testNoReduceAverage">
+ <Resource name="sql">
+ <![CDATA[select name, max(name), avg(deptno), min(name) from sales.dept group by name]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[MAX($0)], EXPR$2=[AVG($1)], EXPR$3=[MIN($0)])
+ LogicalProject(NAME=[$1], DEPTNO=[$0])
+ LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[MAX($0)], EXPR$2=[AVG($1)], EXPR$3=[MIN($0)])
+ LogicalProject(NAME=[$1], DEPTNO=[$0])
+ LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testNoReduceSum">
+ <!-- The rule is built with an empty function set, so SUM is not reduced and planAfter equals planBefore. -->
+ <Resource name="sql">
+ <![CDATA[select name, sum(deptno) from sales.dept group by name]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[SUM($1)])
+ LogicalProject(NAME=[$1], DEPTNO=[$0])
+ LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[SUM($1)])
+ LogicalProject(NAME=[$1], DEPTNO=[$0])
+ LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testReduceAverageAndVarWithNoReduceStddev">
+ <Resource name="sql">
+ <![CDATA[select name, stddev_pop(deptno), avg(deptno), var_pop(deptno)
+from sales.dept group by name]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[STDDEV_POP($1)], EXPR$2=[AVG($1)], EXPR$3=[VAR_POP($1)])
+ LogicalProject(NAME=[$1], DEPTNO=[$0])
+ LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalProject(NAME=[$0], EXPR$1=[$1], EXPR$2=[CAST(/($2, $3)):INTEGER NOT NULL], EXPR$3=[CAST(/(-($4, /(*($2, $2), $3)), $3)):INTEGER NOT NULL])
+ LogicalAggregate(group=[{0}], EXPR$1=[STDDEV_POP($1)], agg#1=[SUM($1)], agg#2=[COUNT()], agg#3=[SUM($2)])
+ LogicalProject(NAME=[$0], DEPTNO=[$1], $f2=[*($1, $1)])
+ LogicalProject(NAME=[$1], DEPTNO=[$0])
+ LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testReduceAverageAndSumWithNoReduceStddevAndVar">
+ <Resource name="sql">
+ <![CDATA[select name, stddev_pop(deptno), avg(deptno), var_pop(deptno)
+from sales.dept group by name]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[STDDEV_POP($1)], EXPR$2=[AVG($1)], EXPR$3=[VAR_POP($1)])
+ LogicalProject(NAME=[$1], DEPTNO=[$0])
+ LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalProject(NAME=[$0], EXPR$1=[$1], EXPR$2=[CAST(/($2, $3)):INTEGER NOT NULL], EXPR$3=[$4])
+ LogicalAggregate(group=[{0}], EXPR$1=[STDDEV_POP($1)], agg#1=[$SUM0($1)], agg#2=[COUNT()], EXPR$3=[VAR_POP($1)])
+ LogicalProject(NAME=[$1], DEPTNO=[$0])
+ LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testReduceAllAggregateFunctions">
+ <Resource name="sql">
+ <![CDATA[select name, stddev_pop(deptno), avg(deptno), stddev_samp(deptno), var_pop(deptno), var_samp(deptno)
+from sales.dept group by name]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[STDDEV_POP($1)], EXPR$2=[AVG($1)], EXPR$3=[STDDEV_SAMP($1)], EXPR$4=[VAR_POP($1)], EXPR$5=[VAR_SAMP($1)])
+ LogicalProject(NAME=[$1], DEPTNO=[$0])
+ LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalProject(NAME=[$0], EXPR$1=[CAST(POWER(/(-($1, /(*($2, $2), $3)), $3), 0.5:DECIMAL(2, 1))):INTEGER NOT NULL], EXPR$2=[CAST(/($2, $3)):INTEGER NOT NULL], EXPR$3=[CAST(POWER(/(-($1, /(*($2, $2), $3)), CASE(=($3, 1), null:BIGINT, -($3, 1))), 0.5:DECIMAL(2, 1))):INTEGER NOT NULL], EXPR$4=[CAST(/(-($1, /(*($2, $2), $3)), $3)):INTEGER NOT NULL], EXPR$5=[CAST(/(-($1, /(*($2, $2), $3)), CASE(=($3, 1), null:BIGINT, -($3, 1)))):INTEGER NOT NULL])
+ LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], agg#1=[$SUM0($1)], agg#2=[COUNT()])
+ LogicalProject(NAME=[$0], DEPTNO=[$1], $f2=[*($1, $1)])
+ LogicalProject(NAME=[$1], DEPTNO=[$0])
+ LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+ </Resource>
+ </TestCase>
</Root>