This is an automated email from the ASF dual-hosted git repository.
danny0405 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/master by this push:
new 7bb120d [CALCITE-3721] Filter of distinct aggregate call is lost
after applying AggregateExpandDistinctAggregatesRule (Shuo Cheng)
7bb120d is described below
commit 7bb120dc074fb20ea5802f2a574dd141efebbcb9
Author: shuo.cs <[email protected]>
AuthorDate: Fri Jan 10 12:31:42 2020 +0800
[CALCITE-3721] Filter of distinct aggregate call is lost after applying
AggregateExpandDistinctAggregatesRule (Shuo Cheng)
* Generate a boolean input field for distinct aggregate calls with a filter
* Remove useless expressions from projects
close apache/calcite#1758
---
.../AggregateExpandDistinctAggregatesRule.java | 74 +++++++++++++++-------
.../org/apache/calcite/test/RelOptRulesTest.java | 29 +++++++++
.../org/apache/calcite/test/RelOptRulesTest.xml | 69 ++++++++++++++++++++
3 files changed, 148 insertions(+), 24 deletions(-)
diff --git
a/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java
b/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java
index 1ba326f..57c50cb 100644
---
a/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java
+++
b/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java
@@ -16,12 +16,10 @@
*/
package org.apache.calcite.rel.rules;
-import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.plan.Contexts;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.RelCollations;
-import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.Aggregate.Group;
import org.apache.calcite.rel.core.AggregateCall;
@@ -403,15 +401,30 @@ public final class AggregateExpandDistinctAggregatesRule
extends RelOptRule {
Aggregate aggregate) {
final Set<ImmutableBitSet> groupSetTreeSet =
new TreeSet<>(ImmutableBitSet.ORDERING);
+ // Maps each groupSet to the filterArgs of the agg calls that use it;
+ // filterArg is -1 for a non-distinct agg call.
+
+ // A `Set` is used here because two agg calls may have
+ // different filterArgs but the same groupSet.
+ final Map<ImmutableBitSet, Set<Integer>> distinctFilterArgMap = new
HashMap<>();
for (AggregateCall aggCall : aggregate.getAggCallList()) {
+ ImmutableBitSet groupSet;
+ int filterArg;
if (!aggCall.isDistinct()) {
+ filterArg = -1;
+ groupSet = aggregate.getGroupSet();
groupSetTreeSet.add(aggregate.getGroupSet());
} else {
- groupSetTreeSet.add(
+ filterArg = aggCall.filterArg;
+ groupSet =
ImmutableBitSet.of(aggCall.getArgList())
- .setIf(aggCall.filterArg, aggCall.filterArg >= 0)
- .union(aggregate.getGroupSet()));
+ .setIf(filterArg, filterArg >= 0)
+ .union(aggregate.getGroupSet());
+ groupSetTreeSet.add(groupSet);
}
+ Set<Integer> filterList = distinctFilterArgMap
+ .computeIfAbsent(groupSet, g -> new HashSet<>());
+ filterList.add(filterArg);
}
final ImmutableList<ImmutableBitSet> groupSets =
@@ -433,38 +446,52 @@ public final class AggregateExpandDistinctAggregatesRule
extends RelOptRule {
relBuilder.push(aggregate.getInput());
final int groupCount = fullGroupSet.cardinality();
- final Map<ImmutableBitSet, Integer> filters = new LinkedHashMap<>();
- final int z = groupCount + distinctAggCalls.size();
+ // Get the base ordinal of filter args for different groupSets.
+ final Map<Pair<ImmutableBitSet, Integer>, Integer> filters = new
LinkedHashMap<>();
+ int z = groupCount + distinctAggCalls.size();
+ for (ImmutableBitSet groupSet: groupSets) {
+ Set<Integer> filterArgList = distinctFilterArgMap.get(groupSet);
+ for (Integer filterArg: filterArgList) {
+ filters.put(Pair.of(groupSet, filterArg), z);
+ z += 1;
+ }
+ }
+
distinctAggCalls.add(
AggregateCall.create(SqlStdOperatorTable.GROUPING, false, false, false,
ImmutableIntList.copyOf(fullGroupSet), -1, RelCollations.EMPTY,
groupSets.size(), relBuilder.peek(), null, "$g"));
- for (Ord<ImmutableBitSet> groupSet : Ord.zip(groupSets)) {
- filters.put(groupSet.e, z + groupSet.i);
- }
relBuilder.aggregate(
relBuilder.groupKey(fullGroupSet,
(Iterable<ImmutableBitSet>) groupSets),
distinctAggCalls);
- final RelNode distinct = relBuilder.peek();
// GROUPING returns an integer (0 or 1). Add a project to convert those
// values to BOOLEAN.
if (!filters.isEmpty()) {
final List<RexNode> nodes = new ArrayList<>(relBuilder.fields());
final RexNode nodeZ = nodes.remove(nodes.size() - 1);
- for (Map.Entry<ImmutableBitSet, Integer> entry : filters.entrySet()) {
- final long v = groupValue(fullGroupSet, entry.getKey());
+ for (Map.Entry<Pair<ImmutableBitSet, Integer>, Integer> entry :
filters.entrySet()) {
+ final long v = groupValue(fullGroupSet, entry.getKey().left);
+ int distinctFilterArg = remap(fullGroupSet, entry.getKey().right);
+ RexNode expr = relBuilder.equals(nodeZ, relBuilder.literal(v));
+ if (distinctFilterArg > -1) {
+ // 'AND' the filter of the distinct aggregate call and the group value.
+ expr = relBuilder.and(expr,
+ relBuilder.call(SqlStdOperatorTable.IS_TRUE,
+ relBuilder.field(distinctFilterArg)));
+ }
+ // "f" means filter.
nodes.add(
- relBuilder.alias(
- relBuilder.equals(nodeZ, relBuilder.literal(v)),
- "$g_" + v));
+ relBuilder.alias(expr,
+ "$g_" + v + (distinctFilterArg < 0 ? "" : "_f_" +
distinctFilterArg)));
}
relBuilder.project(nodes);
}
int x = groupCount;
+ final ImmutableBitSet groupSet = aggregate.getGroupSet();
final List<AggregateCall> newCalls = new ArrayList<>();
for (AggregateCall aggCall : aggregate.getAggCallList()) {
final int newFilterArg;
@@ -473,27 +500,26 @@ public final class AggregateExpandDistinctAggregatesRule
extends RelOptRule {
if (!aggCall.isDistinct()) {
aggregation = SqlStdOperatorTable.MIN;
newArgList = ImmutableIntList.of(x++);
- newFilterArg = filters.get(aggregate.getGroupSet());
+ newFilterArg = filters.get(Pair.of(groupSet, -1));
} else {
aggregation = aggCall.getAggregation();
newArgList = remap(fullGroupSet, aggCall.getArgList());
- newFilterArg =
- filters.get(
- ImmutableBitSet.of(aggCall.getArgList())
- .setIf(aggCall.filterArg, aggCall.filterArg >= 0)
- .union(aggregate.getGroupSet()));
+ final ImmutableBitSet newGroupSet =
ImmutableBitSet.of(aggCall.getArgList())
+ .setIf(aggCall.filterArg, aggCall.filterArg >= 0)
+ .union(groupSet);
+ newFilterArg = filters.get(Pair.of(newGroupSet, aggCall.filterArg));
}
final AggregateCall newCall =
AggregateCall.create(aggregation, false,
aggCall.isApproximate(), aggCall.ignoreNulls(),
newArgList, newFilterArg, aggCall.collation,
- aggregate.getGroupCount(), distinct, null, aggCall.name);
+ aggregate.getGroupCount(), relBuilder.peek(), null,
aggCall.name);
newCalls.add(newCall);
}
relBuilder.aggregate(
relBuilder.groupKey(
- remap(fullGroupSet, aggregate.getGroupSet()),
+ remap(fullGroupSet, groupSet),
(Iterable<ImmutableBitSet>)
remap(fullGroupSet, aggregate.getGroupSets())),
newCalls);
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index 7b614e0..01bcff8 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -1482,6 +1482,35 @@ public class RelOptRulesTest extends RelOptTestBase {
sql(sql).with(program).check();
}
+ @Test public void testDistinctWithFilterWithoutGroupBy() {
+ final String sql = "SELECT SUM(comm), COUNT(DISTINCT sal) FILTER (WHERE
sal > 1000)\n"
+ + "FROM emp";
+ HepProgram program = new HepProgramBuilder()
+ .addRuleInstance(AggregateExpandDistinctAggregatesRule.INSTANCE)
+ .build();
+ sql(sql).with(program).check();
+ }
+
+ @Test public void testDistinctWithDiffFiltersAndSameGroupSet() {
+ final String sql = "SELECT COUNT(DISTINCT c) FILTER (WHERE d),\n"
+ + "COUNT(DISTINCT d) FILTER (WHERE c)\n"
+ + "FROM (select sal > 1000 is true as c, sal < 500 is true as d, comm
from emp)";
+ HepProgram program = new HepProgramBuilder()
+ .addRuleInstance(AggregateExpandDistinctAggregatesRule.INSTANCE)
+ .build();
+ sql(sql).with(program).check();
+ }
+
+ @Test public void testDistinctWithFilterAndGroupBy() {
+ final String sql = "SELECT deptno, SUM(comm), COUNT(DISTINCT sal) FILTER
(WHERE sal > 1000)\n"
+ + "FROM emp\n"
+ + "GROUP BY deptno";
+ HepProgram program = new HepProgramBuilder()
+ .addRuleInstance(AggregateExpandDistinctAggregatesRule.INSTANCE)
+ .build();
+ sql(sql).with(program).check();
+ }
+
@Test public void testPushProjectPastFilter() {
final String sql = "select empno + deptno from emp where sal = 10 * comm\n"
+ "and upper(ename) = 'FOO'";
diff --git
a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index 63016e1..695c171 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -1153,6 +1153,75 @@ LogicalAggregate(group=[{0}], EXPR$1=[SUM($3)],
EXPR$2=[MIN($4)], EXPR$3=[COUNT(
]]>
</Resource>
</TestCase>
+ <TestCase name="testDistinctWithFilterWithoutGroupBy">
+ <Resource name="sql">
+ <![CDATA[SELECT SUM(comm), COUNT(DISTINCT sal) FILTER (WHERE sal >
1000)
+FROM emp]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[SUM($0)], EXPR$1=[COUNT(DISTINCT $1)
FILTER $2])
+ LogicalProject(COMM=[$6], SAL=[$5], $f2=[>($5, 1000)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[MIN($2) FILTER $4], EXPR$1=[COUNT($0)
FILTER $3])
+ LogicalProject(SAL=[$0], $f2=[$1], EXPR$0=[$2], $g_0_f_1=[AND(=($3, 0), IS
TRUE($1))], $g_3=[=($3, 3)])
+ LogicalAggregate(group=[{1, 2}], groups=[[{1, 2}, {}]], EXPR$0=[SUM($0)],
$g=[GROUPING($1, $2)])
+ LogicalProject(COMM=[$6], SAL=[$5], $f2=[>($5, 1000)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testDistinctWithDiffFiltersAndSameGroupSet">
+ <Resource name="sql">
+ <![CDATA[SELECT COUNT(DISTINCT c) FILTER (WHERE d),
+COUNT(DISTINCT d) FILTER (WHERE c)
+FROM (select sal > 1000 is true as c, sal < 500 is true as d from emp)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[COUNT(DISTINCT $0) FILTER $1],
EXPR$1=[COUNT(DISTINCT $1) FILTER $0])
+ LogicalProject(C=[>($5, 1000)], D=[<($5, 500)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[COUNT($0) FILTER $3], EXPR$1=[COUNT($1)
FILTER $2])
+ LogicalProject(C=[$0], D=[$1], $g_0_f_0=[AND(=($2, 0), $0)],
$g_0_f_1=[AND(=($2, 0), $1)])
+ LogicalAggregate(group=[{0, 1}], $g=[GROUPING($0, $1)])
+ LogicalProject(C=[>($5, 1000)], D=[<($5, 500)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testDistinctWithFilterAndGroupBy">
+ <Resource name="sql">
+ <![CDATA[SELECT deptno, SUM(comm), COUNT(DISTINCT sal) FILTER
(WHERE sal > 1000)
+FROM emp
+GROUP BY deptno]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[SUM($1)], EXPR$2=[COUNT(DISTINCT $2)
FILTER $3])
+ LogicalProject(DEPTNO=[$7], COMM=[$6], SAL=[$5], $f3=[>($5, 1000)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalProject(DEPTNO=[$0], EXPR$1=[CAST($1):INTEGER NOT NULL], EXPR$2=[$2])
+ LogicalAggregate(group=[{0}], EXPR$1=[MIN($3) FILTER $5], EXPR$2=[COUNT($1)
FILTER $4])
+ LogicalProject(DEPTNO=[$0], SAL=[$1], $f3=[$2], EXPR$1=[$3],
$g_0_f_2=[AND(=($4, 0), IS TRUE($2))], $g_3=[=($4, 3)])
+ LogicalAggregate(group=[{0, 2, 3}], groups=[[{0, 2, 3}, {0}]],
EXPR$1=[SUM($1)], $g=[GROUPING($0, $2, $3)])
+ LogicalProject(DEPTNO=[$7], COMM=[$6], SAL=[$5], $f3=[>($5, 1000)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ </TestCase>
<TestCase name="testEmptyAggregate">
<Resource name="sql">
<![CDATA[select sum(empno) from emp where false group by deptno]]>