This is an automated email from the ASF dual-hosted git repository.

danny0405 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/master by this push:
     new 7bb120d  [CALCITE-3721] Filter of distinct aggregate call is lost 
after applying AggregateExpandDistinctAggregatesRule (Shuo Cheng)
7bb120d is described below

commit 7bb120dc074fb20ea5802f2a574dd141efebbcb9
Author: shuo.cs <[email protected]>
AuthorDate: Fri Jan 10 12:31:42 2020 +0800

    [CALCITE-3721] Filter of distinct aggregate call is lost after applying 
AggregateExpandDistinctAggregatesRule (Shuo Cheng)
    
    * Generate a boolean input field for each distinct aggregate call that has a filter
    * Remove useless expressions from projects
    
    close apache/calcite#1758
---
 .../AggregateExpandDistinctAggregatesRule.java     | 74 +++++++++++++++-------
 .../org/apache/calcite/test/RelOptRulesTest.java   | 29 +++++++++
 .../org/apache/calcite/test/RelOptRulesTest.xml    | 69 ++++++++++++++++++++
 3 files changed, 148 insertions(+), 24 deletions(-)

diff --git 
a/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java
 
b/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java
index 1ba326f..57c50cb 100644
--- 
a/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java
+++ 
b/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java
@@ -16,12 +16,10 @@
  */
 package org.apache.calcite.rel.rules;
 
-import org.apache.calcite.linq4j.Ord;
 import org.apache.calcite.plan.Contexts;
 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptRuleCall;
 import org.apache.calcite.rel.RelCollations;
-import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.core.Aggregate;
 import org.apache.calcite.rel.core.Aggregate.Group;
 import org.apache.calcite.rel.core.AggregateCall;
@@ -403,15 +401,30 @@ public final class AggregateExpandDistinctAggregatesRule 
extends RelOptRule {
       Aggregate aggregate) {
     final Set<ImmutableBitSet> groupSetTreeSet =
         new TreeSet<>(ImmutableBitSet.ORDERING);
+    // GroupSet to distinct filter arg map,
+    // filterArg will be -1 for non-distinct agg call.
+
+    // Using `Set` here because it's possible that two agg calls
+    // have different filterArgs but same groupSet.
+    final Map<ImmutableBitSet, Set<Integer>> distinctFilterArgMap = new 
HashMap<>();
     for (AggregateCall aggCall : aggregate.getAggCallList()) {
+      ImmutableBitSet groupSet;
+      int filterArg;
       if (!aggCall.isDistinct()) {
+        filterArg = -1;
+        groupSet = aggregate.getGroupSet();
         groupSetTreeSet.add(aggregate.getGroupSet());
       } else {
-        groupSetTreeSet.add(
+        filterArg = aggCall.filterArg;
+        groupSet =
             ImmutableBitSet.of(aggCall.getArgList())
-                .setIf(aggCall.filterArg, aggCall.filterArg >= 0)
-                .union(aggregate.getGroupSet()));
+                .setIf(filterArg, filterArg >= 0)
+                .union(aggregate.getGroupSet());
+        groupSetTreeSet.add(groupSet);
       }
+      Set<Integer> filterList = distinctFilterArgMap
+          .computeIfAbsent(groupSet, g -> new HashSet<>());
+      filterList.add(filterArg);
     }
 
     final ImmutableList<ImmutableBitSet> groupSets =
@@ -433,38 +446,52 @@ public final class AggregateExpandDistinctAggregatesRule 
extends RelOptRule {
     relBuilder.push(aggregate.getInput());
     final int groupCount = fullGroupSet.cardinality();
 
-    final Map<ImmutableBitSet, Integer> filters = new LinkedHashMap<>();
-    final int z = groupCount + distinctAggCalls.size();
+    // Get the base ordinal of filter args for different groupSets.
+    final Map<Pair<ImmutableBitSet, Integer>, Integer> filters = new 
LinkedHashMap<>();
+    int z = groupCount + distinctAggCalls.size();
+    for (ImmutableBitSet groupSet: groupSets) {
+      Set<Integer> filterArgList = distinctFilterArgMap.get(groupSet);
+      for (Integer filterArg: filterArgList) {
+        filters.put(Pair.of(groupSet, filterArg), z);
+        z += 1;
+      }
+    }
+
     distinctAggCalls.add(
         AggregateCall.create(SqlStdOperatorTable.GROUPING, false, false, false,
             ImmutableIntList.copyOf(fullGroupSet), -1, RelCollations.EMPTY,
             groupSets.size(), relBuilder.peek(), null, "$g"));
-    for (Ord<ImmutableBitSet> groupSet : Ord.zip(groupSets)) {
-      filters.put(groupSet.e, z + groupSet.i);
-    }
 
     relBuilder.aggregate(
         relBuilder.groupKey(fullGroupSet,
             (Iterable<ImmutableBitSet>) groupSets),
         distinctAggCalls);
-    final RelNode distinct = relBuilder.peek();
 
     // GROUPING returns an integer (0 or 1). Add a project to convert those
     // values to BOOLEAN.
     if (!filters.isEmpty()) {
       final List<RexNode> nodes = new ArrayList<>(relBuilder.fields());
       final RexNode nodeZ = nodes.remove(nodes.size() - 1);
-      for (Map.Entry<ImmutableBitSet, Integer> entry : filters.entrySet()) {
-        final long v = groupValue(fullGroupSet, entry.getKey());
+      for (Map.Entry<Pair<ImmutableBitSet, Integer>, Integer> entry : 
filters.entrySet()) {
+        final long v = groupValue(fullGroupSet, entry.getKey().left);
+        int distinctFilterArg = remap(fullGroupSet, entry.getKey().right);
+        RexNode expr = relBuilder.equals(nodeZ, relBuilder.literal(v));
+        if (distinctFilterArg > -1) {
+          // 'AND' the filter of the distinct aggregate call and the group 
value.
+          expr = relBuilder.and(expr,
+              relBuilder.call(SqlStdOperatorTable.IS_TRUE,
+                  relBuilder.field(distinctFilterArg)));
+        }
+        // "f" means filter.
         nodes.add(
-            relBuilder.alias(
-                relBuilder.equals(nodeZ, relBuilder.literal(v)),
-                "$g_" + v));
+            relBuilder.alias(expr,
+            "$g_" + v + (distinctFilterArg < 0 ? "" : "_f_" + 
distinctFilterArg)));
       }
       relBuilder.project(nodes);
     }
 
     int x = groupCount;
+    final ImmutableBitSet groupSet = aggregate.getGroupSet();
     final List<AggregateCall> newCalls = new ArrayList<>();
     for (AggregateCall aggCall : aggregate.getAggCallList()) {
       final int newFilterArg;
@@ -473,27 +500,26 @@ public final class AggregateExpandDistinctAggregatesRule 
extends RelOptRule {
       if (!aggCall.isDistinct()) {
         aggregation = SqlStdOperatorTable.MIN;
         newArgList = ImmutableIntList.of(x++);
-        newFilterArg = filters.get(aggregate.getGroupSet());
+        newFilterArg = filters.get(Pair.of(groupSet, -1));
       } else {
         aggregation = aggCall.getAggregation();
         newArgList = remap(fullGroupSet, aggCall.getArgList());
-        newFilterArg =
-            filters.get(
-                ImmutableBitSet.of(aggCall.getArgList())
-                    .setIf(aggCall.filterArg, aggCall.filterArg >= 0)
-                    .union(aggregate.getGroupSet()));
+        final ImmutableBitSet newGroupSet = 
ImmutableBitSet.of(aggCall.getArgList())
+            .setIf(aggCall.filterArg, aggCall.filterArg >= 0)
+            .union(groupSet);
+        newFilterArg = filters.get(Pair.of(newGroupSet, aggCall.filterArg));
       }
       final AggregateCall newCall =
           AggregateCall.create(aggregation, false,
               aggCall.isApproximate(), aggCall.ignoreNulls(),
               newArgList, newFilterArg, aggCall.collation,
-              aggregate.getGroupCount(), distinct, null, aggCall.name);
+              aggregate.getGroupCount(), relBuilder.peek(), null, 
aggCall.name);
       newCalls.add(newCall);
     }
 
     relBuilder.aggregate(
         relBuilder.groupKey(
-            remap(fullGroupSet, aggregate.getGroupSet()),
+            remap(fullGroupSet, groupSet),
             (Iterable<ImmutableBitSet>)
                 remap(fullGroupSet, aggregate.getGroupSets())),
         newCalls);
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java 
b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index 7b614e0..01bcff8 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -1482,6 +1482,35 @@ public class RelOptRulesTest extends RelOptTestBase {
     sql(sql).with(program).check();
   }
 
+  @Test public void testDistinctWithFilterWithoutGroupBy() {
+    final String sql = "SELECT SUM(comm), COUNT(DISTINCT sal) FILTER (WHERE 
sal > 1000)\n"
+        + "FROM emp";
+    HepProgram program = new HepProgramBuilder()
+        .addRuleInstance(AggregateExpandDistinctAggregatesRule.INSTANCE)
+        .build();
+    sql(sql).with(program).check();
+  }
+
+  @Test public void testDistinctWithDiffFiltersAndSameGroupSet() {
+    final String sql = "SELECT COUNT(DISTINCT c) FILTER (WHERE d),\n"
+        + "COUNT(DISTINCT d) FILTER (WHERE c)\n"
+        + "FROM (select sal > 1000 is true as c, sal < 500 is true as d, comm 
from emp)";
+    HepProgram program = new HepProgramBuilder()
+        .addRuleInstance(AggregateExpandDistinctAggregatesRule.INSTANCE)
+        .build();
+    sql(sql).with(program).check();
+  }
+
+  @Test public void testDistinctWithFilterAndGroupBy() {
+    final String sql = "SELECT deptno, SUM(comm), COUNT(DISTINCT sal) FILTER 
(WHERE sal > 1000)\n"
+        + "FROM emp\n"
+        + "GROUP BY deptno";
+    HepProgram program = new HepProgramBuilder()
+        .addRuleInstance(AggregateExpandDistinctAggregatesRule.INSTANCE)
+        .build();
+    sql(sql).with(program).check();
+  }
+
   @Test public void testPushProjectPastFilter() {
     final String sql = "select empno + deptno from emp where sal = 10 * comm\n"
         + "and upper(ename) = 'FOO'";
diff --git 
a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml 
b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index 63016e1..695c171 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -1153,6 +1153,75 @@ LogicalAggregate(group=[{0}], EXPR$1=[SUM($3)], 
EXPR$2=[MIN($4)], EXPR$3=[COUNT(
 ]]>
         </Resource>
     </TestCase>
+    <TestCase name="testDistinctWithFilterWithoutGroupBy">
+        <Resource name="sql">
+            <![CDATA[SELECT SUM(comm), COUNT(DISTINCT sal) FILTER (WHERE sal > 
1000)
+FROM emp]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[SUM($0)], EXPR$1=[COUNT(DISTINCT $1) 
FILTER $2])
+  LogicalProject(COMM=[$6], SAL=[$5], $f2=[>($5, 1000)])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+        <Resource name="planAfter">
+            <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[MIN($2) FILTER $4], EXPR$1=[COUNT($0) 
FILTER $3])
+  LogicalProject(SAL=[$0], $f2=[$1], EXPR$0=[$2], $g_0_f_1=[AND(=($3, 0), IS 
TRUE($1))], $g_3=[=($3, 3)])
+    LogicalAggregate(group=[{1, 2}], groups=[[{1, 2}, {}]], EXPR$0=[SUM($0)], 
$g=[GROUPING($1, $2)])
+      LogicalProject(COMM=[$6], SAL=[$5], $f2=[>($5, 1000)])
+        LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+    </TestCase>
+    <TestCase name="testDistinctWithDiffFiltersAndSameGroupSet">
+        <Resource name="sql">
+            <![CDATA[SELECT COUNT(DISTINCT c) FILTER (WHERE d),
+COUNT(DISTINCT d) FILTER (WHERE c)
+FROM (select sal > 1000 is true as c, sal < 500 is true as d from emp)]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[COUNT(DISTINCT $0) FILTER $1], 
EXPR$1=[COUNT(DISTINCT $1) FILTER $0])
+  LogicalProject(C=[>($5, 1000)], D=[<($5, 500)])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+        <Resource name="planAfter">
+            <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[COUNT($0) FILTER $3], EXPR$1=[COUNT($1) 
FILTER $2])
+  LogicalProject(C=[$0], D=[$1], $g_0_f_0=[AND(=($2, 0), $0)], 
$g_0_f_1=[AND(=($2, 0), $1)])
+    LogicalAggregate(group=[{0, 1}], $g=[GROUPING($0, $1)])
+      LogicalProject(C=[>($5, 1000)], D=[<($5, 500)])
+        LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+    </TestCase>
+    <TestCase name="testDistinctWithFilterAndGroupBy">
+        <Resource name="sql">
+            <![CDATA[SELECT deptno, SUM(comm), COUNT(DISTINCT sal) FILTER 
(WHERE sal > 1000)
+FROM emp
+GROUP BY deptno]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[SUM($1)], EXPR$2=[COUNT(DISTINCT $2) 
FILTER $3])
+  LogicalProject(DEPTNO=[$7], COMM=[$6], SAL=[$5], $f3=[>($5, 1000)])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+        <Resource name="planAfter">
+            <![CDATA[
+LogicalProject(DEPTNO=[$0], EXPR$1=[CAST($1):INTEGER NOT NULL], EXPR$2=[$2])
+  LogicalAggregate(group=[{0}], EXPR$1=[MIN($3) FILTER $5], EXPR$2=[COUNT($1) 
FILTER $4])
+    LogicalProject(DEPTNO=[$0], SAL=[$1], $f3=[$2], EXPR$1=[$3], 
$g_0_f_2=[AND(=($4, 0), IS TRUE($2))], $g_3=[=($4, 3)])
+      LogicalAggregate(group=[{0, 2, 3}], groups=[[{0, 2, 3}, {0}]], 
EXPR$1=[SUM($1)], $g=[GROUPING($0, $2, $3)])
+        LogicalProject(DEPTNO=[$7], COMM=[$6], SAL=[$5], $f3=[>($5, 1000)])
+          LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+    </TestCase>
     <TestCase name="testEmptyAggregate">
         <Resource name="sql">
             <![CDATA[select sum(empno) from emp where false group by deptno]]>

Reply via email to