Repository: calcite
Updated Branches:
  refs/heads/master 483c0a61b -> 73a09a691


[CALCITE-2195] AggregateJoinTransposeRule fails to aggregate over unique column (Zhong Yu)

Close apache/calcite#637


Project: http://git-wip-us.apache.org/repos/asf/calcite/repo
Commit: http://git-wip-us.apache.org/repos/asf/calcite/commit/73a09a69
Tree: http://git-wip-us.apache.org/repos/asf/calcite/tree/73a09a69
Diff: http://git-wip-us.apache.org/repos/asf/calcite/diff/73a09a69

Branch: refs/heads/master
Commit: 73a09a691a67eec46349c94b21cf3fa483775c02
Parents: 483c0a6
Author: yuzhong <yuzhong...@alibaba-inc.com>
Authored: Tue Feb 27 22:41:14 2018 +0800
Committer: Julian Hyde <jh...@apache.org>
Committed: Fri Mar 2 10:26:12 2018 -0800

----------------------------------------------------------------------
 .../rel/rules/AggregateJoinTransposeRule.java   | 35 ++++++++++++++---
 .../apache/calcite/test/RelOptRulesTest.java    | 17 +++++++++
 .../org/apache/calcite/test/RelOptRulesTest.xml | 40 +++++++++++++++++---
 3 files changed, 81 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/calcite/blob/73a09a69/core/src/main/java/org/apache/calcite/rel/rules/AggregateJoinTransposeRule.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateJoinTransposeRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateJoinTransposeRule.java
index d7c86aa..1068702 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/AggregateJoinTransposeRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateJoinTransposeRule.java
@@ -197,6 +197,11 @@ public class AggregateJoinTransposeRule extends RelOptRule {
       for (Ord<Integer> c : Ord.zip(belowAggregateKeyNotShifted)) {
         map.put(c.e, belowOffset + c.i);
       }
+      final Mappings.TargetMapping mapping =
+          s == 0
+              ? Mappings.createIdentity(fieldCount)
+              : Mappings.createShiftMapping(fieldCount + offset, 0, offset,
+                  fieldCount);
       final ImmutableBitSet belowAggregateKey =
           belowAggregateKeyNotShifted.shift(-offset);
       final boolean unique;
@@ -224,17 +229,35 @@ public class AggregateJoinTransposeRule extends RelOptRule {
       if (unique) {
         ++uniqueCount;
         side.aggregate = false;
-        side.newInput = joinInput;
+        relBuilder.push(joinInput);
+        final List<RexNode> projects = new ArrayList<>();
+        for (Integer i : belowAggregateKey) {
+          projects.add(relBuilder.field(i));
+        }
+        for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) 
{
+          final SqlAggFunction aggregation = aggCall.e.getAggregation();
+          final SqlSplittableAggFunction splitter =
+              Preconditions.checkNotNull(
+                  aggregation.unwrap(SqlSplittableAggFunction.class));
+          if (!aggCall.e.getArgList().isEmpty()
+              && 
fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) {
+            final RexNode singleton = splitter.singleton(rexBuilder,
+                joinInput.getRowType(), aggCall.e.transform(mapping));
+            if (singleton instanceof RexInputRef) {
+              side.split.put(aggCall.i, ((RexInputRef) singleton).getIndex());
+            } else {
+              projects.add(singleton);
+              side.split.put(aggCall.i, projects.size() - 1);
+            }
+          }
+        }
+        relBuilder.project(projects);
+        side.newInput = relBuilder.build();
       } else {
         side.aggregate = true;
         List<AggregateCall> belowAggCalls = new ArrayList<>();
         final SqlSplittableAggFunction.Registry<AggregateCall>
             belowAggCallRegistry = registry(belowAggCalls);
-        final Mappings.TargetMapping mapping =
-            s == 0
-                ? Mappings.createIdentity(fieldCount)
-                : Mappings.createShiftMapping(fieldCount + offset, 0, offset,
-                    fieldCount);
         final int oldGroupKeyCount = aggregate.getGroupCount();
         final int newGroupKeyCount = belowAggregateKey.cardinality();
         for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) 
{

http://git-wip-us.apache.org/repos/asf/calcite/blob/73a09a69/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index 2e2a5f6..cc995e6 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -3010,6 +3010,23 @@ public class RelOptRulesTest extends RelOptTestBase {
     sql(sql).withPre(preProgram).with(program).check();
   }
 
+  /** Test case for
+   * <a href="https://issues.apache.org/jira/browse/CALCITE-2195">[CALCITE-2195]
+   * AggregateJoinTransposeRule fails to aggregate over unique column</a>. */
+  @Test public void testPushAggregateThroughJoin6() {
+    final HepProgram preProgram = new HepProgramBuilder()
+        .addRuleInstance(AggregateProjectMergeRule.INSTANCE)
+        .build();
+    final HepProgram program = new HepProgramBuilder()
+        .addRuleInstance(AggregateJoinTransposeRule.EXTENDED)
+        .build();
+    final String sql = "select sum(B.sal)\n"
+        + "from sales.emp as A\n"
+        + "join (select distinct sal from sales.emp) as B\n"
+        + "on A.sal=B.sal\n";
+    sql(sql).withPre(preProgram).with(program).check();
+  }
+
   /** SUM is the easiest aggregate function to split. */
   @Test public void testPushAggregateSumThroughJoin() {
     final HepProgram preProgram = new HepProgramBuilder()

http://git-wip-us.apache.org/repos/asf/calcite/blob/73a09a69/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
----------------------------------------------------------------------
diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index 70e54c8..3016285 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -5806,7 +5806,8 @@ LogicalProject(DEPTNO=[$0])
   LogicalJoin(condition=[=($0, $1)], joinType=[inner])
     LogicalAggregate(group=[{7}])
       LogicalTableScan(table=[[CATALOG, SALES, EMP]])
-    LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+    LogicalProject(DEPTNO=[$0])
+      LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
 ]]>
         </Resource>
     </TestCase>
@@ -5828,14 +5829,43 @@ LogicalProject(DEPTNO=[$0], DEPTNO0=[$1])
         <Resource name="planAfter">
             <![CDATA[
 LogicalProject(DEPTNO=[$0], DEPTNO0=[$1])
-  LogicalProject(DEPTNO=[$0], DEPTNO0=[$1])
-    LogicalJoin(condition=[=($0, $1)], joinType=[inner])
-      LogicalAggregate(group=[{7}])
-        LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+  LogicalJoin(condition=[=($0, $1)], joinType=[inner])
+    LogicalAggregate(group=[{7}])
+      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+    LogicalProject(DEPTNO=[$0])
       LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
 ]]>
         </Resource>
     </TestCase>
+    <TestCase name="testPushAggregateThroughJoin6">
+        <Resource name="sql">
+            <![CDATA[select sum(B.sal)
+from sales.emp as A
+join (select distinct sal from sales.emp) as B
+on A.sal=B.sal
+]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[SUM($9)])
+  LogicalJoin(condition=[=($5, $9)], joinType=[inner])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+    LogicalAggregate(group=[{5}])
+      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+        <Resource name="planAfter">
+            <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[SUM($3)])
+  LogicalProject(SAL=[$0], $f1=[$1], SAL0=[$2], $f3=[CAST(*($1, $2)):INTEGER])
+    LogicalJoin(condition=[=($0, $2)], joinType=[inner])
+      LogicalAggregate(group=[{5}], agg#0=[COUNT()])
+        LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+      LogicalAggregate(group=[{5}])
+        LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+    </TestCase>
     <TestCase name="testPushAggregateSumThroughJoin">
         <Resource name="sql">
             <![CDATA[select e.job,sum(sal)

Reply via email to