This is an automated email from the ASF dual-hosted git repository.

sereda pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/master by this push:
     new a75a689  [CALCITE-2820] Avoid reducing certain aggregate functions 
when it is not necessary (Siddharth Teotia)
a75a689 is described below

commit a75a689eff2f1333adc8fb800bdfa077e94da562
Author: siddharth <[email protected]>
AuthorDate: Wed Feb 6 12:06:13 2019 -0800

    [CALCITE-2820] Avoid reducing certain aggregate functions when
    it is not necessary (Siddharth Teotia)
---
 .../rel/rules/AggregateReduceFunctionsRule.java    |  62 +++++++++--
 .../org/apache/calcite/test/RelOptRulesTest.java   |  81 ++++++++++++++
 .../org/apache/calcite/test/RelOptRulesTest.xml    | 123 +++++++++++++++++++++
 3 files changed, 254 insertions(+), 12 deletions(-)

diff --git 
a/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
 
b/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
index e62c786..b2f416f 100644
--- 
a/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
+++ 
b/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
@@ -46,9 +46,11 @@ import com.google.common.collect.ImmutableList;
 import java.math.BigDecimal;
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.EnumSet;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 
 /**
  * Planner rule that reduces aggregate functions in
@@ -97,16 +99,59 @@ public class AggregateReduceFunctionsRule extends 
RelOptRule {
       new AggregateReduceFunctionsRule(operand(LogicalAggregate.class, any()),
           RelFactories.LOGICAL_BUILDER);
 
+  private final EnumSet<SqlKind> functionsToReduce;
+
   //~ Constructors -----------------------------------------------------------
 
-  /** Creates an AggregateReduceFunctionsRule. */
+  /**
+   * Creates an AggregateReduceFunctionsRule to reduce all functions
+   * handled by this rule.
+   * @param operand operand to determine if rule can be applied
+   * @param relBuilderFactory builder for relational expressions
+   */
   public AggregateReduceFunctionsRule(RelOptRuleOperand operand,
       RelBuilderFactory relBuilderFactory) {
     super(operand, relBuilderFactory, null);
+    functionsToReduce = EnumSet.noneOf(SqlKind.class);
+    addDefaultSetOfFunctionsToReduce();
+  }
+
+  /**
+   * Creates an AggregateReduceFunctionsRule with client-provided
+   * information on which specific functions will be reduced
+   * by this rule.
+   * @param aggregateClass aggregate class
+   * @param relBuilderFactory builder for relational expressions
+   * @param functionsToReduce client provided information
+   *                          on which specific functions
+   *                          will be reduced by this rule
+   */
+  public AggregateReduceFunctionsRule(Class<? extends Aggregate> 
aggregateClass,
+      RelBuilderFactory relBuilderFactory, EnumSet<SqlKind> functionsToReduce) 
{
+    super(operand(aggregateClass, any()), relBuilderFactory, null);
+    Objects.requireNonNull(functionsToReduce,
+        "Expecting a valid handle for AggregateFunctionsToReduce");
+    this.functionsToReduce = EnumSet.noneOf(SqlKind.class);
+    for (SqlKind function : functionsToReduce) {
+      if (SqlKind.AVG_AGG_FUNCTIONS.contains(function)
+          || SqlKind.COVAR_AVG_AGG_FUNCTIONS.contains(function)
+          || function == SqlKind.SUM) {
+        this.functionsToReduce.add(function);
+      } else {
+        throw new IllegalArgumentException(
+          "AggregateReduceFunctionsRule doesn't support function: " + 
function.sql);
+      }
+    }
   }
 
   //~ Methods ----------------------------------------------------------------
 
+  private void addDefaultSetOfFunctionsToReduce() {
+    functionsToReduce.addAll(SqlKind.AVG_AGG_FUNCTIONS);
+    functionsToReduce.addAll(SqlKind.COVAR_AVG_AGG_FUNCTIONS);
+    functionsToReduce.add(SqlKind.SUM);
+  }
+
   @Override public boolean matches(RelOptRuleCall call) {
     if (!super.matches(call)) {
       return false;
@@ -138,20 +183,13 @@ public class AggregateReduceFunctionsRule extends 
RelOptRule {
    * Returns whether the aggregate call is a reducible function
    */
   private boolean isReducible(final SqlKind kind) {
-    if (SqlKind.AVG_AGG_FUNCTIONS.contains(kind)
-        || SqlKind.COVAR_AVG_AGG_FUNCTIONS.contains(kind)) {
-      return true;
-    }
-    switch (kind) {
-    case SUM:
-      return true;
-    }
-    return false;
+    return functionsToReduce.contains(kind);
   }
 
   /**
-   * Reduces all calls to AVG, STDDEV_POP, STDDEV_SAMP, VAR_POP, VAR_SAMP in
-   * the aggregates list to.
+   * Reduces calls to functions AVG, SUM, STDDEV_POP, STDDEV_SAMP, VAR_POP,
+   * VAR_SAMP, COVAR_POP, COVAR_SAMP, REGR_SXX, REGR_SYY if the function is
+   * present in {@link AggregateReduceFunctionsRule#functionsToReduce}.
    *
    * <p>It handles newly generated common subexpressions since this was done
    * at the sql2rel stage.
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java 
b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index b4ce16e..d033818 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -44,6 +44,7 @@ import org.apache.calcite.rel.core.Minus;
 import org.apache.calcite.rel.core.Project;
 import org.apache.calcite.rel.core.RelFactories;
 import org.apache.calcite.rel.core.Union;
+import org.apache.calcite.rel.logical.LogicalAggregate;
 import org.apache.calcite.rel.logical.LogicalCorrelate;
 import org.apache.calcite.rel.logical.LogicalProject;
 import org.apache.calcite.rel.logical.LogicalTableModify;
@@ -137,6 +138,7 @@ import org.junit.Test;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.EnumSet;
 import java.util.List;
 import java.util.Properties;
 import java.util.function.Predicate;
@@ -4347,6 +4349,85 @@ public class RelOptRulesTest extends RelOptTestBase {
     final String planAfter = NL + RelOptUtil.toString(relAfter);
     diffRepos.assertEquals("planAfter", "${planAfter}", planAfter);
   }
+
+  @Test public void testReduceAverageWithNoReduceSum() {
+    final EnumSet<SqlKind> functionsToReduce = EnumSet.of(SqlKind.AVG);
+    checkPlanning(
+        new AggregateReduceFunctionsRule(
+          LogicalAggregate.class, RelFactories.LOGICAL_BUILDER,
+            functionsToReduce),
+                  "select name, max(name), avg(deptno), min(name)"
+                          + " from sales.dept group by name");
+  }
+
+  @Test public void testNoReduceAverage() {
+    final EnumSet<SqlKind> functionsToReduce = EnumSet.noneOf(SqlKind.class);
+    HepProgram program = new HepProgramBuilder()
+        .addRuleInstance(
+          new AggregateReduceFunctionsRule(LogicalAggregate.class,
+            RelFactories.LOGICAL_BUILDER, functionsToReduce))
+        .build();
+    String sql = "select name, max(name), avg(deptno), min(name)"
+        + " from sales.dept group by name";
+    sql(sql).with(program).checkUnchanged();
+  }
+
+  @Test public void testNoReduceSum() {
+    final EnumSet<SqlKind> functionsToReduce = EnumSet.noneOf(SqlKind.class);
+    HepProgram program = new HepProgramBuilder()
+            .addRuleInstance(
+              new AggregateReduceFunctionsRule(LogicalAggregate.class,
+                RelFactories.LOGICAL_BUILDER, functionsToReduce))
+            .build();
+    String sql = "select name, sum(deptno)"
+            + " from sales.dept group by name";
+    sql(sql).with(program).checkUnchanged();
+  }
+
+  @Test public void testReduceAverageAndVarWithNoReduceStddev() {
+    // configure rule to reduce AVG and VAR_POP functions
+    // other functions like SUM, STDDEV won't be reduced
+    final EnumSet<SqlKind> functionsToReduce = EnumSet.of(SqlKind.AVG, 
SqlKind.VAR_POP);
+    HepProgram program = new HepProgramBuilder()
+        .addRuleInstance(
+         new AggregateReduceFunctionsRule(LogicalAggregate.class,
+           RelFactories.LOGICAL_BUILDER, functionsToReduce))
+        .build();
+    final String sql = "select name, stddev_pop(deptno), avg(deptno),"
+        + " var_pop(deptno)\n"
+        + "from sales.dept group by name";
+    sql(sql).with(program).check();
+  }
+
+  @Test public void testReduceAverageAndSumWithNoReduceStddevAndVar() {
+    // configure rule to reduce AVG and SUM functions
+    // other functions like VAR_POP, STDDEV_POP won't be reduced
+    final EnumSet<SqlKind> functionsToReduce = EnumSet.of(SqlKind.AVG, 
SqlKind.SUM);
+    HepProgram program = new HepProgramBuilder()
+        .addRuleInstance(
+          new AggregateReduceFunctionsRule(LogicalAggregate.class,
+            RelFactories.LOGICAL_BUILDER, functionsToReduce))
+        .build();
+    final String sql = "select name, stddev_pop(deptno), avg(deptno),"
+        + " var_pop(deptno)\n"
+        + "from sales.dept group by name";
+    sql(sql).with(program).check();
+  }
+
+  @Test public void testReduceAllAggregateFunctions() {
+    // configure rule to reduce all used functions
+    final EnumSet<SqlKind> functionsToReduce = EnumSet.of(SqlKind.AVG, 
SqlKind.SUM,
+        SqlKind.STDDEV_POP, SqlKind.STDDEV_SAMP, SqlKind.VAR_POP, 
SqlKind.VAR_SAMP);
+    HepProgram program = new HepProgramBuilder()
+        .addRuleInstance(
+          new AggregateReduceFunctionsRule(LogicalAggregate.class,
+            RelFactories.LOGICAL_BUILDER, functionsToReduce))
+        .build();
+    final String sql = "select name, stddev_pop(deptno), avg(deptno),"
+        + " stddev_samp(deptno), var_pop(deptno), var_samp(deptno)\n"
+        + "from sales.dept group by name";
+    sql(sql).with(program).check();
+  }
 }
 
 // End RelOptRulesTest.java
diff --git 
a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml 
b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index 45af69c..adfd110 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -8874,4 +8874,127 @@ LogicalSortExchange(distribution=[hash[1]], 
collation=[[1]])
 ]]>
         </Resource>
     </TestCase>
+    <TestCase name="testReduceAverageWithNoReduceSum">
+        <Resource name="sql">
+            <![CDATA[select name, max(name), avg(deptno), min(name) from 
sales.dept group by name]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[MAX($0)], EXPR$2=[AVG($1)], 
EXPR$3=[MIN($0)])
+  LogicalProject(NAME=[$1], DEPTNO=[$0])
+    LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+        </Resource>
+        <Resource name="planAfter">
+            <![CDATA[
+LogicalProject(NAME=[$0], EXPR$1=[$1], EXPR$2=[CAST(/($2, $3)):INTEGER NOT 
NULL], EXPR$3=[$4])
+  LogicalAggregate(group=[{0}], EXPR$1=[MAX($0)], agg#1=[SUM($1)], 
agg#2=[COUNT()], EXPR$3=[MIN($0)])
+    LogicalProject(NAME=[$1], DEPTNO=[$0])
+      LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+        </Resource>
+    </TestCase>
+    <TestCase name="testNoReduceAverage">
+        <Resource name="sql">
+            <![CDATA[select name, max(name), avg(deptno), min(name) from 
sales.dept group by name]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[MAX($0)], EXPR$2=[AVG($1)], 
EXPR$3=[MIN($0)])
+  LogicalProject(NAME=[$1], DEPTNO=[$0])
+    LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+        </Resource>
+        <Resource name="planAfter">
+            <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[MAX($0)], EXPR$2=[AVG($1)], 
EXPR$3=[MIN($0)])
+  LogicalProject(NAME=[$1], DEPTNO=[$0])
+    LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+        </Resource>
+    </TestCase>
+    <TestCase name="testNoReduceSum">
+        <Resource name="sql">
+            <![CDATA[select name, sum(deptno) from sales.dept group by name]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[SUM($1)])
+  LogicalProject(NAME=[$1], DEPTNO=[$0])
+    LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+        </Resource>
+        <Resource name="planAfter">
+            <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[SUM($1)])
+  LogicalProject(NAME=[$1], DEPTNO=[$0])
+    LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+        </Resource>
+    </TestCase>
+    <TestCase name="testReduceAverageAndVarWithNoReduceStddev">
+        <Resource name="sql">
+            <![CDATA[select name, stddev_pop(deptno), avg(deptno), var_pop(deptno)
+from sales.dept group by name]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[STDDEV_POP($1)], EXPR$2=[AVG($1)], 
EXPR$3=[VAR_POP($1)])
+  LogicalProject(NAME=[$1], DEPTNO=[$0])
+    LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+        </Resource>
+        <Resource name="planAfter">
+            <![CDATA[
+LogicalProject(NAME=[$0], EXPR$1=[$1], EXPR$2=[CAST(/($2, $3)):INTEGER NOT 
NULL], EXPR$3=[CAST(/(-($4, /(*($2, $2), $3)), $3)):INTEGER NOT NULL])
+  LogicalAggregate(group=[{0}], EXPR$1=[STDDEV_POP($1)], agg#1=[SUM($1)], 
agg#2=[COUNT()], agg#3=[SUM($2)])
+    LogicalProject(NAME=[$0], DEPTNO=[$1], $f2=[*($1, $1)])
+      LogicalProject(NAME=[$1], DEPTNO=[$0])
+        LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+        </Resource>
+    </TestCase>
+    <TestCase name="testReduceAverageAndSumWithNoReduceStddevAndVar">
+        <Resource name="sql">
+            <![CDATA[select name, stddev_pop(deptno), avg(deptno), var_pop(deptno)
+from sales.dept group by name]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[STDDEV_POP($1)], EXPR$2=[AVG($1)], 
EXPR$3=[VAR_POP($1)])
+  LogicalProject(NAME=[$1], DEPTNO=[$0])
+    LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+        </Resource>
+        <Resource name="planAfter">
+            <![CDATA[
+LogicalProject(NAME=[$0], EXPR$1=[$1], EXPR$2=[CAST(/($2, $3)):INTEGER NOT 
NULL], EXPR$3=[$4])
+  LogicalAggregate(group=[{0}], EXPR$1=[STDDEV_POP($1)], agg#1=[$SUM0($1)], 
agg#2=[COUNT()], EXPR$3=[VAR_POP($1)])
+    LogicalProject(NAME=[$1], DEPTNO=[$0])
+      LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+        </Resource>
+    </TestCase>
+    <TestCase name="testReduceAllAggregateFunctions">
+        <Resource name="sql">
+            <![CDATA[select name, stddev_pop(deptno), avg(deptno), 
stddev_samp(deptno), var_pop(deptno), var_samp(deptno)
+from sales.dept group by name]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[STDDEV_POP($1)], EXPR$2=[AVG($1)], 
EXPR$3=[STDDEV_SAMP($1)], EXPR$4=[VAR_POP($1)], EXPR$5=[VAR_SAMP($1)])
+  LogicalProject(NAME=[$1], DEPTNO=[$0])
+    LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+        </Resource>
+        <Resource name="planAfter">
+            <![CDATA[
+LogicalProject(NAME=[$0], EXPR$1=[CAST(POWER(/(-($1, /(*($2, $2), $3)), $3), 
0.5:DECIMAL(2, 1))):INTEGER NOT NULL], EXPR$2=[CAST(/($2, $3)):INTEGER NOT 
NULL], EXPR$3=[CAST(POWER(/(-($1, /(*($2, $2), $3)), CASE(=($3, 1), 
null:BIGINT, -($3, 1))), 0.5:DECIMAL(2, 1))):INTEGER NOT NULL], 
EXPR$4=[CAST(/(-($1, /(*($2, $2), $3)), $3)):INTEGER NOT NULL], 
EXPR$5=[CAST(/(-($1, /(*($2, $2), $3)), CASE(=($3, 1), null:BIGINT, -($3, 
1)))):INTEGER NOT NULL])
+  LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], agg#1=[$SUM0($1)], 
agg#2=[COUNT()])
+    LogicalProject(NAME=[$0], DEPTNO=[$1], $f2=[*($1, $1)])
+      LogicalProject(NAME=[$1], DEPTNO=[$0])
+        LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+        </Resource>
+    </TestCase>
 </Root>

Reply via email to