Repository: flink
Updated Branches:
  refs/heads/master 8bcb2ae3c -> 1a062b796


[FLINK-3475] [table] Add support for DISTINCT aggregates in SQL queries.

This closes #3111.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/36c9348f
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/36c9348f
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/36c9348f

Branch: refs/heads/master
Commit: 36c9348ff06cae1fe55925bcc6081154be2f10f5
Parents: 8bcb2ae
Author: Zhenghua Gao <[email protected]>
Authored: Thu Jan 12 10:33:27 2017 +0800
Committer: Fabian Hueske <[email protected]>
Committed: Mon Feb 27 22:50:13 2017 +0100

----------------------------------------------------------------------
 docs/dev/table_api.md                           |    3 +-
 ...nkAggregateExpandDistinctAggregatesRule.java | 1152 ++++++++++++++++++
 .../flink/table/plan/rules/FlinkRuleSets.scala  |    5 +-
 .../rules/dataSet/DataSetAggregateRule.scala    |    3 -
 .../DataSetAggregateWithNullValuesRule.scala    |    3 -
 .../scala/batch/sql/AggregationsITCase.scala    |   27 +-
 .../scala/batch/sql/DistinctAggregateTest.scala |  476 ++++++++
 .../batch/sql/QueryDecorrelationTest.scala      |    2 +-
 8 files changed, 1651 insertions(+), 20 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/36c9348f/docs/dev/table_api.md
----------------------------------------------------------------------
diff --git a/docs/dev/table_api.md b/docs/dev/table_api.md
index 22fd636..1b43b7c 100644
--- a/docs/dev/table_api.md
+++ b/docs/dev/table_api.md
@@ -1324,13 +1324,12 @@ val result = tableEnv.sql(
 
 #### Limitations
 
-The current version supports selection (filter), projection, inner equi-joins, 
grouping, non-distinct aggregates, and sorting on batch tables.
+The current version supports selection (filter), projection, inner equi-joins, 
grouping, aggregates, and sorting on batch tables.
 
 Among others, the following SQL features are not supported, yet:
 
 - Timestamps and intervals are limited to milliseconds precision
 - Interval arithmetic is currently limited
-- Distinct aggregates (e.g., `COUNT(DISTINCT name)`)
 - Non-equi joins and Cartesian products
 - Efficient grouping sets
 

http://git-wip-us.apache.org/repos/asf/flink/blob/36c9348f/flink-libraries/flink-table/src/main/java/org/apache/flink/table/calcite/rules/FlinkAggregateExpandDistinctAggregatesRule.java
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-table/src/main/java/org/apache/flink/table/calcite/rules/FlinkAggregateExpandDistinctAggregatesRule.java
 
b/flink-libraries/flink-table/src/main/java/org/apache/flink/table/calcite/rules/FlinkAggregateExpandDistinctAggregatesRule.java
new file mode 100644
index 0000000..d7b1ffa
--- /dev/null
+++ 
b/flink-libraries/flink-table/src/main/java/org/apache/flink/table/calcite/rules/FlinkAggregateExpandDistinctAggregatesRule.java
@@ -0,0 +1,1152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.calcite.rules;
+
+import org.apache.calcite.plan.Contexts;
+import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.core.RelFactories;
+import org.apache.calcite.rel.logical.LogicalAggregate;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.rel.type.RelDataTypeField;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlAggFunction;
+import org.apache.calcite.sql.fun.SqlCountAggFunction;
+import org.apache.calcite.sql.fun.SqlMinMaxAggFunction;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.fun.SqlSumAggFunction;
+import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.calcite.tools.RelBuilder;
+import org.apache.calcite.tools.RelBuilderFactory;
+import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.calcite.util.ImmutableIntList;
+import org.apache.calcite.util.Pair;
+import org.apache.calcite.util.Util;
+
+import org.apache.flink.util.Preconditions;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+
+import java.math.BigDecimal;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.SortedSet;
+import java.util.TreeSet;
+
+/**
+ Copy of Calcite's AggregateExpandDistinctAggregatesRule, brought into the Flink
+ project with a quick fix to avoid the bad cases described in CALCITE-1558.
+ It should be dropped in favor of Calcite's own AggregateExpandDistinctAggregatesRule
+ once we upgrade to Calcite 1.12 (or above).
+ */
+
+/**
+ * Planner rule that expands distinct aggregates
+ * (such as {@code COUNT(DISTINCT x)}) from a
+ * {@link org.apache.calcite.rel.logical.LogicalAggregate}.
+ *
+ * <p>How this is done depends upon the arguments to the function. If all
+ * functions have the same argument
+ * (e.g. {@code COUNT(DISTINCT x), SUM(DISTINCT x)} both have the argument
+ * {@code x}) then one extra {@link org.apache.calcite.rel.core.Aggregate} is
+ * sufficient.
+ *
+ * <p>If there are multiple arguments
+ * (e.g. {@code COUNT(DISTINCT x), COUNT(DISTINCT y)})
+ * the rule creates separate {@code Aggregate}s and combines using a
+ * {@link org.apache.calcite.rel.core.Join}.
+ */
+public final class FlinkAggregateExpandDistinctAggregatesRule extends 
RelOptRule {
+       //~ Static fields/initializers 
---------------------------------------------
+
+       /** The default instance of the rule; operates only on logical 
expressions. */
+       public static final FlinkAggregateExpandDistinctAggregatesRule INSTANCE 
=
+                       new 
FlinkAggregateExpandDistinctAggregatesRule(LogicalAggregate.class, true,
+                                       RelFactories.LOGICAL_BUILDER);
+
+       /** Instance of the rule that operates only on logical expressions and
+        * generates a join. */
+       public static final FlinkAggregateExpandDistinctAggregatesRule JOIN =
+                       new 
FlinkAggregateExpandDistinctAggregatesRule(LogicalAggregate.class, false,
+                                       RelFactories.LOGICAL_BUILDER);
+
+       private static final BigDecimal TWO = BigDecimal.valueOf(2L);
+
+       public final boolean useGroupingSets;
+
+       //~ Constructors 
-----------------------------------------------------------
+
+       public FlinkAggregateExpandDistinctAggregatesRule(
+                       Class<? extends LogicalAggregate> clazz,
+                       boolean useGroupingSets,
+                       RelBuilderFactory relBuilderFactory) {
+               super(operand(clazz, any()), relBuilderFactory, null);
+               this.useGroupingSets = useGroupingSets;
+       }
+
+       @Deprecated // to be removed before 2.0
+       public FlinkAggregateExpandDistinctAggregatesRule(
+                       Class<? extends LogicalAggregate> clazz,
+                       boolean useGroupingSets,
+                       RelFactories.JoinFactory joinFactory) {
+               this(clazz, useGroupingSets, 
RelBuilder.proto(Contexts.of(joinFactory)));
+       }
+
+       @Deprecated // to be removed before 2.0
+       public FlinkAggregateExpandDistinctAggregatesRule(
+                       Class<? extends LogicalAggregate> clazz,
+                       RelFactories.JoinFactory joinFactory) {
+               this(clazz, false, RelBuilder.proto(Contexts.of(joinFactory)));
+       }
+
+       //~ Methods 
----------------------------------------------------------------
+
+       public void onMatch(RelOptRuleCall call) {
+               final Aggregate aggregate = call.rel(0);
+               if (!aggregate.containsDistinctCall()) {
+                       return;
+               }
+
+               // Find all of the agg expressions. We use a LinkedHashSet to 
ensure
+               // determinism.
+               int nonDistinctCount = 0;
+               int distinctCount = 0;
+               int filterCount = 0;
+               int unsupportedAggCount = 0;
+               final Set<Pair<List<Integer>, Integer>> argLists = new 
LinkedHashSet<>();
+               for (AggregateCall aggCall : aggregate.getAggCallList()) {
+                       if (aggCall.filterArg >= 0) {
+                               ++filterCount;
+                       }
+                       if (!aggCall.isDistinct()) {
+                               ++nonDistinctCount;
+                               if (!(aggCall.getAggregation() instanceof 
SqlCountAggFunction
+                                               || aggCall.getAggregation() 
instanceof SqlSumAggFunction
+                                               || aggCall.getAggregation() 
instanceof SqlMinMaxAggFunction)) {
+                                       ++unsupportedAggCount;
+                               }
+                               continue;
+                       }
+                       ++distinctCount;
+                       argLists.add(Pair.of(aggCall.getArgList(), 
aggCall.filterArg));
+               }
+               Preconditions.checkState(argLists.size() > 0, 
"containsDistinctCall lied");
+
+               // If all of the agg expressions are distinct and have the same
+               // arguments then we can use a more efficient form.
+               if (nonDistinctCount == 0 && argLists.size() == 1) {
+                       final Pair<List<Integer>, Integer> pair =
+                                       Iterables.getOnlyElement(argLists);
+                       final RelBuilder relBuilder = call.builder();
+                       convertMonopole(relBuilder, aggregate, pair.left, 
pair.right);
+                       call.transformTo(relBuilder.build());
+                       return;
+               }
+
+               if (useGroupingSets) {
+                       rewriteUsingGroupingSets(call, aggregate, argLists);
+                       return;
+               }
+
+               // If only one distinct aggregate and one or more non-distinct 
aggregates,
+               // we can generate multi-phase aggregates
+               if (distinctCount == 1 // one distinct aggregate
+                               && filterCount == 0 // no filter
+                               && unsupportedAggCount == 0 // 
sum/min/max/count in non-distinct aggregate
+                               && nonDistinctCount > 0) { // one or more 
non-distinct aggregates
+                       final RelBuilder relBuilder = call.builder();
+                       convertSingletonDistinct(relBuilder, aggregate, 
argLists);
+                       call.transformTo(relBuilder.build());
+                       return;
+               }
+
+               // Create a list of the expressions which will yield the final 
result.
+               // Initially, the expressions point to the input field.
+               final List<RelDataTypeField> aggFields =
+                               aggregate.getRowType().getFieldList();
+               final List<RexInputRef> refs = new ArrayList<>();
+               final List<String> fieldNames = 
aggregate.getRowType().getFieldNames();
+               final ImmutableBitSet groupSet = aggregate.getGroupSet();
+               final int groupAndIndicatorCount =
+                               aggregate.getGroupCount() + 
aggregate.getIndicatorCount();
+               for (int i : Util.range(groupAndIndicatorCount)) {
+                       refs.add(RexInputRef.of(i, aggFields));
+               }
+
+               // Aggregate the original relation, including any non-distinct 
aggregates.
+               final List<AggregateCall> newAggCallList = new ArrayList<>();
+               int i = -1;
+               for (AggregateCall aggCall : aggregate.getAggCallList()) {
+                       ++i;
+                       if (aggCall.isDistinct()) {
+                               refs.add(null);
+                               continue;
+                       }
+                       refs.add(
+                                       new RexInputRef(
+                                                       groupAndIndicatorCount 
+ newAggCallList.size(),
+                                                       
aggFields.get(groupAndIndicatorCount + i).getType()));
+                       newAggCallList.add(aggCall);
+               }
+
+               // In the case where there are no non-distinct aggregates 
(regardless of
+               // whether there are group bys), there's no need to generate the
+               // extra aggregate and join.
+               final RelBuilder relBuilder = call.builder();
+               relBuilder.push(aggregate.getInput());
+               int n = 0;
+               if (!newAggCallList.isEmpty()) {
+                       final RelBuilder.GroupKey groupKey =
+                                       relBuilder.groupKey(groupSet, 
aggregate.indicator, aggregate.getGroupSets());
+                       relBuilder.aggregate(groupKey, newAggCallList);
+                       ++n;
+               }
+
+               // For each set of operands, find and rewrite all calls which 
have that
+               // set of operands.
+               for (Pair<List<Integer>, Integer> argList : argLists) {
+                       doRewrite(relBuilder, aggregate, n++, argList.left, 
argList.right, refs);
+               }
+
+               relBuilder.project(refs, fieldNames);
+               call.transformTo(relBuilder.build());
+       }
+
+       /**
+        * Converts an aggregate with one distinct aggregate and one or more
+        * non-distinct aggregates to multi-phase aggregates (see reference 
example
+        * below).
+        *
+        * @param relBuilder Contains the input relational expression
+        * @param aggregate  Original aggregate
+        * @param argLists   Arguments and filters to the distinct aggregate 
function
+        *
+        */
+       private RelBuilder convertSingletonDistinct(RelBuilder relBuilder,
+                                                                               
        Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
+               // For example,
+               //      SELECT deptno, COUNT(*), SUM(bonus), MIN(DISTINCT sal)
+               //      FROM emp
+               //      GROUP BY deptno
+               //
+               // becomes
+               //
+               //      SELECT deptno, SUM(cnt), SUM(bonus), MIN(sal)
+               //      FROM (
+               //                SELECT deptno, COUNT(*) as cnt, SUM(bonus), 
sal
+               //                FROM EMP
+               //                GROUP BY deptno, sal)                 // 
Aggregate B
+               //      GROUP BY deptno                                         
// Aggregate A
+               relBuilder.push(aggregate.getInput());
+               final List<Pair<RexNode, String>> projects = new ArrayList<>();
+               final Map<Integer, Integer> sourceOf = new HashMap<>();
+               SortedSet<Integer> newGroupSet = new TreeSet<>();
+               final List<RelDataTypeField> childFields =
+                               relBuilder.peek().getRowType().getFieldList();
+               final boolean hasGroupBy = aggregate.getGroupSet().size() > 0;
+
+               SortedSet<Integer> groupSet = new 
TreeSet<>(aggregate.getGroupSet().asList());
+
+               // Add the distinct aggregate column(s) to the group-by columns,
+               // if not already a part of the group-by
+               newGroupSet.addAll(aggregate.getGroupSet().asList());
+               for (Pair<List<Integer>, Integer> argList : argLists) {
+                       newGroupSet.addAll(argList.getKey());
+               }
+
+               // Re-map the arguments to the aggregate A. These arguments 
will get
+               // remapped because of the intermediate aggregate B generated 
as part of the
+               // transformation.
+               for (int arg : newGroupSet) {
+                       sourceOf.put(arg, projects.size());
+                       projects.add(RexInputRef.of2(arg, childFields));
+               }
+               // Generate the intermediate aggregate B
+               final List<AggregateCall> aggCalls = aggregate.getAggCallList();
+               final List<AggregateCall> newAggCalls = new ArrayList<>();
+               final List<Integer> fakeArgs = new ArrayList<>();
+               final Map<AggregateCall, Integer> callArgMap = new HashMap<>();
+               // First identify the real arguments, then use the rest for 
fake arguments
+               // e.g. if real arguments are 0, 1, 3. Then the fake arguments 
will be 2, 4
+               for (final AggregateCall aggCall : aggCalls) {
+                       if (!aggCall.isDistinct()) {
+                               for (int arg : aggCall.getArgList()) {
+                                       if (!groupSet.contains(arg)) {
+                                               sourceOf.put(arg, 
projects.size());
+                                       }
+                               }
+                       }
+               }
+               int fakeArg0 = 0;
+               for (final AggregateCall aggCall : aggCalls) {
+                       // We will deal with non-distinct aggregates below
+                       if (!aggCall.isDistinct()) {
+                               boolean isGroupKeyUsedInAgg = false;
+                               for (int arg : aggCall.getArgList()) {
+                                       if (groupSet.contains(arg)) {
+                                               isGroupKeyUsedInAgg = true;
+                                               break;
+                                       }
+                               }
+                               if (aggCall.getArgList().size() == 0 || 
isGroupKeyUsedInAgg) {
+                                       while (sourceOf.get(fakeArg0) != null) {
+                                               ++fakeArg0;
+                                       }
+                                       fakeArgs.add(fakeArg0);
+                                       ++fakeArg0;
+                               }
+                       }
+               }
+               for (final AggregateCall aggCall : aggCalls) {
+                       if (!aggCall.isDistinct()) {
+                               for (int arg : aggCall.getArgList()) {
+                                       if (!groupSet.contains(arg)) {
+                                               sourceOf.remove(arg);
+                                       }
+                               }
+                       }
+               }
+               // Compute the remapped arguments using fake arguments for 
non-distinct
+               // aggregates with no arguments e.g. count(*).
+               int fakeArgIdx = 0;
+               for (final AggregateCall aggCall : aggCalls) {
+                       // Project the column corresponding to the distinct 
aggregate. Project
+                       // as-is all the non-distinct aggregates
+                       if (!aggCall.isDistinct()) {
+                               final AggregateCall newCall =
+                                               
AggregateCall.create(aggCall.getAggregation(), false,
+                                                               
aggCall.getArgList(), -1,
+                                                               
ImmutableBitSet.of(newGroupSet).cardinality(),
+                                                               
relBuilder.peek(), null, aggCall.name);
+                               newAggCalls.add(newCall);
+                               if (newCall.getArgList().size() == 0) {
+                                       int fakeArg = fakeArgs.get(fakeArgIdx);
+                                       callArgMap.put(newCall, fakeArg);
+                                       sourceOf.put(fakeArg, projects.size());
+                                       projects.add(
+                                                       Pair.of((RexNode) new 
RexInputRef(fakeArg, newCall.getType()),
+                                                                       
newCall.getName()));
+                                       ++fakeArgIdx;
+                               } else {
+                                       for (int arg : newCall.getArgList()) {
+                                               if (groupSet.contains(arg)) {
+                                                       int fakeArg = 
fakeArgs.get(fakeArgIdx);
+                                                       callArgMap.put(newCall, 
fakeArg);
+                                                       sourceOf.put(fakeArg, 
projects.size());
+                                                       projects.add(
+                                                                       
Pair.of((RexNode) new RexInputRef(fakeArg, newCall.getType()),
+                                                                               
        newCall.getName()));
+                                                       ++fakeArgIdx;
+                                               } else {
+                                                       sourceOf.put(arg, 
projects.size());
+                                                       projects.add(
+                                                                       
Pair.of((RexNode) new RexInputRef(arg, newCall.getType()),
+                                                                               
        newCall.getName()));
+                                               }
+                                       }
+                               }
+                       }
+               }
+               // Generate the aggregate B (see the reference example above)
+               relBuilder.push(
+                               aggregate.copy(
+                                               aggregate.getTraitSet(), 
relBuilder.build(),
+                                               false, 
ImmutableBitSet.of(newGroupSet), null, newAggCalls));
+               // Convert the existing aggregate to aggregate A (see the 
reference example above)
+               final List<AggregateCall> newTopAggCalls =
+                               Lists.newArrayList(aggregate.getAggCallList());
+               // Use the remapped arguments for the (non)distinct aggregate 
calls
+               for (int i = 0; i < newTopAggCalls.size(); i++) {
+                       // Re-map arguments.
+                       final AggregateCall aggCall = newTopAggCalls.get(i);
+                       final int argCount = aggCall.getArgList().size();
+                       final List<Integer> newArgs = new ArrayList<>(argCount);
+                       final AggregateCall newCall;
+
+
+                       for (int j = 0; j < argCount; j++) {
+                               final Integer arg = aggCall.getArgList().get(j);
+                               if (callArgMap.containsKey(aggCall)) {
+                                       
newArgs.add(sourceOf.get(callArgMap.get(aggCall)));
+                               }
+                               else {
+                                       newArgs.add(sourceOf.get(arg));
+                               }
+                       }
+                       if (aggCall.isDistinct()) {
+                               newCall =
+                                               
AggregateCall.create(aggCall.getAggregation(), false, newArgs,
+                                                               -1, 
aggregate.getGroupSet().cardinality(), relBuilder.peek(),
+                                                               
aggCall.getType(), aggCall.name);
+                       } else {
+                               // If aggregate B had a COUNT aggregate call 
the corresponding aggregate at
+                               // aggregate A must be SUM. For other 
aggregates, it remains the same.
+                               if (aggCall.getAggregation() instanceof 
SqlCountAggFunction) {
+                                       if (aggCall.getArgList().size() == 0) {
+                                               
newArgs.add(sourceOf.get(callArgMap.get(aggCall)));
+                                       }
+                                       if (hasGroupBy) {
+                                               SqlSumAggFunction sumAgg = new 
SqlSumAggFunction(null);
+                                               newCall =
+                                                               
AggregateCall.create(sumAgg, false, newArgs, -1,
+                                                                               
aggregate.getGroupSet().cardinality(), relBuilder.peek(),
+                                                                               
aggCall.getType(), aggCall.getName());
+                                       } else {
+                                               SqlSumEmptyIsZeroAggFunction 
sumAgg = new SqlSumEmptyIsZeroAggFunction();
+                                               newCall =
+                                                               
AggregateCall.create(sumAgg, false, newArgs, -1,
+                                                                               
aggregate.getGroupSet().cardinality(), relBuilder.peek(),
+                                                                               
aggCall.getType(), aggCall.getName());
+                                       }
+                               } else {
+                                       newCall =
+                                                       
AggregateCall.create(aggCall.getAggregation(), false, newArgs, -1,
+                                                                       
aggregate.getGroupSet().cardinality(),
+                                                                       
relBuilder.peek(), aggCall.getType(), aggCall.name);
+                               }
+                       }
+                       newTopAggCalls.set(i, newCall);
+               }
+               // Populate the group-by keys with the remapped arguments for 
aggregate A
+               newGroupSet.clear();
+               for (int arg : aggregate.getGroupSet()) {
+                       newGroupSet.add(sourceOf.get(arg));
+               }
+               relBuilder.push(
+                               aggregate.copy(aggregate.getTraitSet(),
+                                               relBuilder.build(), 
aggregate.indicator,
+                                               
ImmutableBitSet.of(newGroupSet), null, newTopAggCalls));
+               return relBuilder;
+       }
+       /*
+       public RelBuilder convertSingletonDistinct(RelBuilder relBuilder,
+                                                                               
           Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
+               // For example,
+               //      SELECT deptno, COUNT(*), SUM(bonus), MIN(DISTINCT sal)
+               //      FROM emp
+               //      GROUP BY deptno
+               //
+               // becomes
+               //
+               //      SELECT deptno, SUM(cnt), SUM(bonus), MIN(sal)
+               //      FROM (
+               //                SELECT deptno, COUNT(*) as cnt, SUM(bonus), 
sal
+               //                FROM EMP
+               //                GROUP BY deptno, sal)                 // 
Aggregate B
+               //      GROUP BY deptno                                         
// Aggregate A
+               relBuilder.push(aggregate.getInput());
+               final List<Pair<RexNode, String>> projects = new ArrayList<>();
+               final Map<Integer, Integer> sourceOf = new HashMap<>();
+               SortedSet<Integer> newGroupSet = new TreeSet<>();
+               final List<RelDataTypeField> childFields =
+                               relBuilder.peek().getRowType().getFieldList();
+               final boolean hasGroupBy = aggregate.getGroupSet().size() > 0;
+
+               // Add the distinct aggregate column(s) to the group-by columns,
+               // if not already a part of the group-by
+               newGroupSet.addAll(aggregate.getGroupSet().asList());
+               for (Pair<List<Integer>, Integer> argList : argLists) {
+                       newGroupSet.addAll(argList.getKey());
+               }
+
+               // Re-map the arguments to the aggregate A. These arguments 
will get
+               // remapped because of the intermediate aggregate B generated 
as part of the
+               // transformation.
+               for (int arg : newGroupSet) {
+                       sourceOf.put(arg, projects.size());
+                       projects.add(RexInputRef.of2(arg, childFields));
+               }
+               // Generate the intermediate aggregate B
+               final List<AggregateCall> aggCalls = aggregate.getAggCallList();
+               final List<AggregateCall> newAggCalls = new ArrayList<>();
+               final List<Integer> fakeArgs = new ArrayList<>();
+               final Map<AggregateCall, Integer> callArgMap = new HashMap<>();
+               // First identify the real arguments, then use the rest for 
fake arguments
+               // e.g. if real arguments are 0, 1, 3. Then the fake arguments 
will be 2, 4
+               for (final AggregateCall aggCall : aggCalls) {
+                       if (!aggCall.isDistinct()) {
+                               for (int arg : aggCall.getArgList()) {
+                                       if (!sourceOf.containsKey(arg)) {
+                                               sourceOf.put(arg, 
projects.size());
+                                       }
+                               }
+                       }
+               }
+               int fakeArg0 = 0;
+               for (final AggregateCall aggCall : aggCalls) {
+                       // We will deal with non-distinct aggregates below
+                       if (!aggCall.isDistinct()) {
+                               boolean isGroupKeyUsedInAgg = false;
+                               for (int arg : aggCall.getArgList()) {
+                                       if (sourceOf.containsKey(arg)) {
+                                               isGroupKeyUsedInAgg = true;
+                                               break;
+                                       }
+                               }
+                               if (aggCall.getArgList().size() == 0 || 
isGroupKeyUsedInAgg) {
+                                       while (sourceOf.get(fakeArg0) != null) {
+                                               ++fakeArg0;
+                                       }
+                                       fakeArgs.add(fakeArg0);
+                               }
+                       }
+               }
+               for (final AggregateCall aggCall : aggCalls) {
+                       if (!aggCall.isDistinct()) {
+                               for (int arg : aggCall.getArgList()) {
+                                       if (!sourceOf.containsKey(arg)) {
+                                               sourceOf.remove(arg);
+                                       }
+                               }
+                       }
+               }
+               // Compute the remapped arguments using fake arguments for 
non-distinct
+               // aggregates with no arguments e.g. count(*).
+               int fakeArgIdx = 0;
+               for (final AggregateCall aggCall : aggCalls) {
+                       // Project the column corresponding to the distinct 
aggregate. Project
+                       // as-is all the non-distinct aggregates
+                       if (!aggCall.isDistinct()) {
+                               final AggregateCall newCall =
+                                               
AggregateCall.create(aggCall.getAggregation(), false,
+                                                               
aggCall.getArgList(), -1,
+                                                               
ImmutableBitSet.of(newGroupSet).cardinality(),
+                                                               
relBuilder.peek(), null, aggCall.name);
+                               newAggCalls.add(newCall);
+                               if (newCall.getArgList().size() == 0) {
+                                       int fakeArg = fakeArgs.get(fakeArgIdx);
+                                       callArgMap.put(newCall, fakeArg);
+                                       sourceOf.put(fakeArg, projects.size());
+                                       projects.add(
+                                                       Pair.of((RexNode) new 
RexInputRef(fakeArg, newCall.getType()),
+                                                                       
newCall.getName()));
+                                       ++fakeArgIdx;
+                               } else {
+                                       for (int arg : newCall.getArgList()) {
+                                               if (sourceOf.containsKey(arg)) {
+                                                       int fakeArg = 
fakeArgs.get(fakeArgIdx);
+                                                       callArgMap.put(newCall, 
fakeArg);
+                                                       sourceOf.put(fakeArg, 
projects.size());
+                                                       projects.add(
+                                                                       
Pair.of((RexNode) new RexInputRef(fakeArg, newCall.getType()),
+                                                                               
        newCall.getName()));
+                                                       ++fakeArgIdx;
+                                               } else {
+                                                       sourceOf.put(arg, 
projects.size());
+                                                       projects.add(
+                                                                       
Pair.of((RexNode) new RexInputRef(arg, newCall.getType()),
+                                                                               
        newCall.getName()));
+                                               }
+                                       }
+                               }
+                       }
+               }
+               // Generate the aggregate B (see the reference example above)
+               relBuilder.push(
+                               aggregate.copy(
+                                               aggregate.getTraitSet(), 
relBuilder.build(),
+                                               false, 
ImmutableBitSet.of(newGroupSet), null, newAggCalls));
+               // Convert the existing aggregate to aggregate A (see the 
reference example above)
+               final List<AggregateCall> newTopAggCalls =
+                               Lists.newArrayList(aggregate.getAggCallList());
+               // Use the remapped arguments for the (non)distinct aggregate 
calls
+               for (int i = 0; i < newTopAggCalls.size(); i++) {
+                       // Re-map arguments.
+                       final AggregateCall aggCall = newTopAggCalls.get(i);
+                       final int argCount = aggCall.getArgList().size();
+                       final List<Integer> newArgs = new ArrayList<>(argCount);
+                       final AggregateCall newCall;
+
+
+                       for (int j = 0; j < argCount; j++) {
+                               final Integer arg = aggCall.getArgList().get(j);
+                               if (callArgMap.containsKey(aggCall)) {
+                                       
newArgs.add(sourceOf.get(callArgMap.get(aggCall)));
+                               }
+                               else {
+                                       newArgs.add(sourceOf.get(arg));
+                               }
+                       }
+                       if (aggCall.isDistinct()) {
+                               newCall =
+                                               
AggregateCall.create(aggCall.getAggregation(), false, newArgs,
+                                                               -1, 
aggregate.getGroupSet().cardinality(), relBuilder.peek(),
+                                                               
aggCall.getType(), aggCall.name);
+                       } else {
+                               // If aggregate B had a COUNT aggregate call 
the corresponding aggregate at
+                               // aggregate A must be SUM. For other 
aggregates, it remains the same.
+                               if (aggCall.getAggregation() instanceof 
SqlCountAggFunction) {
+                                       if (aggCall.getArgList().size() == 0) {
+                                               
newArgs.add(sourceOf.get(callArgMap.get(aggCall)));
+                                       }
+                                       if (hasGroupBy) {
+                                               SqlSumAggFunction sumAgg = new 
SqlSumAggFunction(null);
+                                               newCall =
+                                                               
AggregateCall.create(sumAgg, false, newArgs, -1,
+                                                                               
aggregate.getGroupSet().cardinality(), relBuilder.peek(),
+                                                                               
aggCall.getType(), aggCall.getName());
+                                       } else {
+                                               SqlSumEmptyIsZeroAggFunction 
sumAgg = new SqlSumEmptyIsZeroAggFunction();
+                                               newCall =
+                                                               
AggregateCall.create(sumAgg, false, newArgs, -1,
+                                                                               
aggregate.getGroupSet().cardinality(), relBuilder.peek(),
+                                                                               
aggCall.getType(), aggCall.getName());
+                                       }
+                               } else {
+                                       newCall =
+                                                       
AggregateCall.create(aggCall.getAggregation(), false, newArgs, -1,
+                                                                       
aggregate.getGroupSet().cardinality(),
+                                                                       
relBuilder.peek(), aggCall.getType(), aggCall.name);
+                               }
+                       }
+                       newTopAggCalls.set(i, newCall);
+               }
+               // Populate the group-by keys with the remapped arguments for 
aggregate A
+               newGroupSet.clear();
+               for (int arg : aggregate.getGroupSet()) {
+                       newGroupSet.add(sourceOf.get(arg));
+               }
+               relBuilder.push(
+                               aggregate.copy(aggregate.getTraitSet(),
+                                               relBuilder.build(), 
aggregate.indicator,
+                                               
ImmutableBitSet.of(newGroupSet), null, newTopAggCalls));
+               return relBuilder;
+       }
+       */
+
+       @SuppressWarnings("DanglingJavadoc")
+       private void rewriteUsingGroupingSets(RelOptRuleCall call,
+                                                                               
Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
+               final Set<ImmutableBitSet> groupSetTreeSet =
+                               new TreeSet<>(ImmutableBitSet.ORDERING);
+               groupSetTreeSet.add(aggregate.getGroupSet());
+               for (Pair<List<Integer>, Integer> argList : argLists) {
+                       groupSetTreeSet.add(
+                                       ImmutableBitSet.of(argList.left)
+                                                       .setIf(argList.right, 
argList.right >= 0)
+                                                       
.union(aggregate.getGroupSet()));
+               }
+
+               final ImmutableList<ImmutableBitSet> groupSets =
+                               ImmutableList.copyOf(groupSetTreeSet);
+               final ImmutableBitSet fullGroupSet = 
ImmutableBitSet.union(groupSets);
+
+               final List<AggregateCall> distinctAggCalls = new ArrayList<>();
+               for (Pair<AggregateCall, String> aggCall : 
aggregate.getNamedAggCalls()) {
+                       if (!aggCall.left.isDistinct()) {
+                               
distinctAggCalls.add(aggCall.left.rename(aggCall.right));
+                       }
+               }
+
+               final RelBuilder relBuilder = call.builder();
+               relBuilder.push(aggregate.getInput());
+               relBuilder.aggregate(relBuilder.groupKey(fullGroupSet, 
groupSets.size() > 1, groupSets),
+                               distinctAggCalls);
+               final RelNode distinct = relBuilder.peek();
+               final int groupCount = fullGroupSet.cardinality();
+               final int indicatorCount = groupSets.size() > 1 ? groupCount : 
0;
+
+               final RelOptCluster cluster = aggregate.getCluster();
+               final RexBuilder rexBuilder = cluster.getRexBuilder();
+               final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
+               final RelDataType booleanType =
+                               typeFactory.createTypeWithNullability(
+                                               
typeFactory.createSqlType(SqlTypeName.BOOLEAN), false);
+               final List<Pair<RexNode, String>> predicates = new 
ArrayList<>();
+               final Map<ImmutableBitSet, Integer> filters = new HashMap<>();
+
+               /** Function to register a filter for a group set. */
+               class Registrar {
+                       RexNode group = null;
+
+                       private int register(ImmutableBitSet groupSet) {
+                               if (group == null) {
+                                       group = makeGroup(groupCount - 1);
+                               }
+                               final RexNode node =
+                                               
rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, group,
+                                                               
rexBuilder.makeExactLiteral(
+                                                                               
toNumber(remap(fullGroupSet, groupSet))));
+                               predicates.add(Pair.of(node, 
toString(groupSet)));
+                               return groupCount + indicatorCount + 
distinctAggCalls.size()
+                                               + predicates.size() - 1;
+                       }
+
+                       private RexNode makeGroup(int i) {
+                               final RexInputRef ref =
+                                               
rexBuilder.makeInputRef(booleanType, groupCount + i);
+                               final RexNode kase =
+                                               
rexBuilder.makeCall(SqlStdOperatorTable.CASE, ref,
+                                                               
rexBuilder.makeExactLiteral(BigDecimal.ZERO),
+                                                               
rexBuilder.makeExactLiteral(TWO.pow(i)));
+                               if (i == 0) {
+                                       return kase;
+                               } else {
+                                       return 
rexBuilder.makeCall(SqlStdOperatorTable.PLUS,
+                                                       makeGroup(i - 1), kase);
+                               }
+                       }
+
+                       private BigDecimal toNumber(ImmutableBitSet bitSet) {
+                               BigDecimal n = BigDecimal.ZERO;
+                               for (int key : bitSet) {
+                                       n = n.add(TWO.pow(key));
+                               }
+                               return n;
+                       }
+
+                       private String toString(ImmutableBitSet bitSet) {
+                               final StringBuilder buf = new 
StringBuilder("$i");
+                               for (int key : bitSet) {
+                                       buf.append(key).append('_');
+                               }
+                               return buf.substring(0, buf.length() - 1);
+                       }
+               }
+               final Registrar registrar = new Registrar();
+               for (ImmutableBitSet groupSet : groupSets) {
+                       filters.put(groupSet, registrar.register(groupSet));
+               }
+
+               if (!predicates.isEmpty()) {
+                       List<Pair<RexNode, String>> nodes = new ArrayList<>();
+                       for (RelDataTypeField f : 
relBuilder.peek().getRowType().getFieldList()) {
+                               final RexNode node = 
rexBuilder.makeInputRef(f.getType(), f.getIndex());
+                               nodes.add(Pair.of(node, f.getName()));
+                       }
+                       nodes.addAll(predicates);
+                       relBuilder.project(Pair.left(nodes), Pair.right(nodes));
+               }
+
+               int x = groupCount + indicatorCount;
+               final List<AggregateCall> newCalls = new ArrayList<>();
+               for (AggregateCall aggCall : aggregate.getAggCallList()) {
+                       final int newFilterArg;
+                       final List<Integer> newArgList;
+                       final SqlAggFunction aggregation;
+                       if (!aggCall.isDistinct()) {
+                               aggregation = SqlStdOperatorTable.MIN;
+                               newArgList = ImmutableIntList.of(x++);
+                               newFilterArg = 
filters.get(aggregate.getGroupSet());
+                       } else {
+                               aggregation = aggCall.getAggregation();
+                               newArgList = remap(fullGroupSet, 
aggCall.getArgList());
+                               newFilterArg =
+                                               filters.get(
+                                                               
ImmutableBitSet.of(aggCall.getArgList())
+                                                                               
.setIf(aggCall.filterArg, aggCall.filterArg >= 0)
+                                                                               
.union(aggregate.getGroupSet()));
+                       }
+                       final AggregateCall newCall =
+                                       AggregateCall.create(aggregation, 
false, newArgList, newFilterArg,
+                                                       
aggregate.getGroupCount(), distinct, null, aggCall.name);
+                       newCalls.add(newCall);
+               }
+
+               relBuilder.aggregate(
+                               relBuilder.groupKey(
+                                               remap(fullGroupSet, 
aggregate.getGroupSet()),
+                                               aggregate.indicator,
+                                               remap(fullGroupSet, 
aggregate.getGroupSets())),
+                               newCalls);
+               relBuilder.convert(aggregate.getRowType(), true);
+               call.transformTo(relBuilder.build());
+       }
+
+       private static ImmutableBitSet remap(ImmutableBitSet groupSet,
+                                                                               
ImmutableBitSet bitSet) {
+               final ImmutableBitSet.Builder builder = 
ImmutableBitSet.builder();
+               for (Integer bit : bitSet) {
+                       builder.set(remap(groupSet, bit));
+               }
+               return builder.build();
+       }
+
+       private static ImmutableList<ImmutableBitSet> remap(ImmutableBitSet 
groupSet,
+                                                                               
                                Iterable<ImmutableBitSet> bitSets) {
+               final ImmutableList.Builder<ImmutableBitSet> builder =
+                               ImmutableList.builder();
+               for (ImmutableBitSet bitSet : bitSets) {
+                       builder.add(remap(groupSet, bitSet));
+               }
+               return builder.build();
+       }
+
+       private static List<Integer> remap(ImmutableBitSet groupSet,
+                                                                       
List<Integer> argList) {
+               ImmutableIntList list = ImmutableIntList.of();
+               for (int arg : argList) {
+                       list = list.append(remap(groupSet, arg));
+               }
+               return list;
+       }
+
+       private static int remap(ImmutableBitSet groupSet, int arg) {
+               return arg < 0 ? -1 : groupSet.indexOf(arg);
+       }
+
+       /**
+        * Converts an aggregate relational expression that contains just one
+        * distinct aggregate function (or perhaps several over the same 
arguments)
+        * and no non-distinct aggregate functions.
+        */
+       private RelBuilder convertMonopole(RelBuilder relBuilder, Aggregate 
aggregate,
+                                                                       
List<Integer> argList, int filterArg) {
+               // For example,
+               //      SELECT deptno, COUNT(DISTINCT sal), SUM(DISTINCT sal)
+               //      FROM emp
+               //      GROUP BY deptno
+               //
+               // becomes
+               //
+               //      SELECT deptno, COUNT(distinct_sal), SUM(distinct_sal)
+               //      FROM (
+               //        SELECT DISTINCT deptno, sal AS distinct_sal
+               //        FROM EMP GROUP BY deptno)
+               //      GROUP BY deptno
+
+               // Project the columns of the GROUP BY plus the arguments
+               // to the agg function.
+               final Map<Integer, Integer> sourceOf = new HashMap<>();
+               createSelectDistinct(relBuilder, aggregate, argList, filterArg, 
sourceOf);
+
+               // Create an aggregate on top, with the new aggregate list.
+               final List<AggregateCall> newAggCalls =
+                               Lists.newArrayList(aggregate.getAggCallList());
+               rewriteAggCalls(newAggCalls, argList, sourceOf);
+               final int cardinality = aggregate.getGroupSet().cardinality();
+               relBuilder.push(
+                               aggregate.copy(aggregate.getTraitSet(), 
relBuilder.build(),
+                                               aggregate.indicator, 
ImmutableBitSet.range(cardinality), null,
+                                               newAggCalls));
+               return relBuilder;
+       }
+
+       /**
+        * Converts all distinct aggregate calls to a given set of arguments.
+        *
+        * <p>This method is called several times, one for each set of 
arguments.
+        * Each time it is called, it generates a JOIN to a new SELECT DISTINCT
+        * relational expression, and modifies the set of top-level calls.
+        *
+        * @param aggregate Original aggregate
+        * @param n              Ordinal of this in a join. {@code relBuilder} 
contains the
+        *                                input relational expression (either 
the original
+        *                                aggregate, the output from the 
previous call to this
+        *                                method. {@code n} is 0 if we're 
converting the
+        *                                first distinct aggregate in a query 
with no non-distinct
+        *                                aggregates)
+        * @param argList   Arguments to the distinct aggregate function
+        * @param filterArg Argument that filters input to aggregate function, 
or -1
+        * @param refs    Array of expressions which will be the projected by 
the
+        *                                result of this rule. Those relating 
to this arg list will
+        *                                be modified  @return Relational 
expression
+        */
+       private void doRewrite(RelBuilder relBuilder, Aggregate aggregate, int 
n,
+                                               List<Integer> argList, int 
filterArg, List<RexInputRef> refs) {
+               final RexBuilder rexBuilder = 
aggregate.getCluster().getRexBuilder();
+               final List<RelDataTypeField> leftFields;
+               if (n == 0) {
+                       leftFields = null;
+               } else {
+                       leftFields = 
relBuilder.peek().getRowType().getFieldList();
+               }
+
+               // LogicalAggregate(
+               //       child,
+               //       {COUNT(DISTINCT 1), SUM(DISTINCT 1), SUM(2)})
+               //
+               // becomes
+               //
+               // LogicalAggregate(
+               //       LogicalJoin(
+               //               child,
+               //               LogicalAggregate(child, < all columns > {}),
+               //               INNER,
+               //               <f2 = f5>))
+               //
+               // E.g.
+               //   SELECT deptno, SUM(DISTINCT sal), COUNT(DISTINCT gender), 
MAX(age)
+               //   FROM Emps
+               //   GROUP BY deptno
+               //
+               // becomes
+               //
+               //   SELECT e.deptno, adsal.sum_sal, adgender.count_gender, 
e.max_age
+               //   FROM (
+               //       SELECT deptno, MAX(age) as max_age
+               //       FROM Emps GROUP BY deptno) AS e
+               //   JOIN (
+               //       SELECT deptno, COUNT(gender) AS count_gender FROM (
+               //         SELECT DISTINCT deptno, gender FROM Emps) AS dgender
+               //       GROUP BY deptno) AS adgender
+               //       ON e.deptno = adgender.deptno
+               //   JOIN (
+               //       SELECT deptno, SUM(sal) AS sum_sal FROM (
+               //         SELECT DISTINCT deptno, sal FROM Emps) AS dsal
+               //       GROUP BY deptno) AS adsal
+               //   ON e.deptno = adsal.deptno
+               //   GROUP BY e.deptno
+               //
+               // Note that if a query contains no non-distinct aggregates, 
then the
+               // very first join/group by is omitted.  In the example above, 
if
+               // MAX(age) is removed, then the sub-select of "e" is not 
needed, and
+               // instead the two other group by's are joined to one another.
+
+               // Project the columns of the GROUP BY plus the arguments
+               // to the agg function.
+               final Map<Integer, Integer> sourceOf = new HashMap<>();
+               createSelectDistinct(relBuilder, aggregate, argList, filterArg, 
sourceOf);
+
+               // Now compute the aggregate functions on top of the distinct 
dataset.
+               // Each distinct agg becomes a non-distinct call to the 
corresponding
+               // field from the right; for example,
+               //   "COUNT(DISTINCT e.sal)"
+               // becomes
+               //   "COUNT(distinct_e.sal)".
+               final List<AggregateCall> aggCallList = new ArrayList<>();
+               final List<AggregateCall> aggCalls = aggregate.getAggCallList();
+
+               final int groupAndIndicatorCount =
+                               aggregate.getGroupCount() + 
aggregate.getIndicatorCount();
+               int i = groupAndIndicatorCount - 1;
+               for (AggregateCall aggCall : aggCalls) {
+                       ++i;
+
+                       // Ignore agg calls which are not distinct or have the 
wrong set
+                       // arguments. If we're rewriting aggs whose args are 
{sal}, we will
+                       // rewrite COUNT(DISTINCT sal) and SUM(DISTINCT sal) 
but ignore
+                       // COUNT(DISTINCT gender) or SUM(sal).
+                       if (!aggCall.isDistinct()) {
+                               continue;
+                       }
+                       if (!aggCall.getArgList().equals(argList)) {
+                               continue;
+                       }
+
+                       // Re-map arguments.
+                       final int argCount = aggCall.getArgList().size();
+                       final List<Integer> newArgs = new ArrayList<>(argCount);
+                       for (int j = 0; j < argCount; j++) {
+                               final Integer arg = aggCall.getArgList().get(j);
+                               newArgs.add(sourceOf.get(arg));
+                       }
+                       final int newFilterArg =
+                                       aggCall.filterArg >= 0 ? 
sourceOf.get(aggCall.filterArg) : -1;
+                       final AggregateCall newAggCall =
+                                       
AggregateCall.create(aggCall.getAggregation(), false, newArgs,
+                                                       newFilterArg, 
aggCall.getType(), aggCall.getName());
+                       assert refs.get(i) == null;
+                       if (n == 0) {
+                               refs.set(i,
+                                               new 
RexInputRef(groupAndIndicatorCount + aggCallList.size(),
+                                                               
newAggCall.getType()));
+                       } else {
+                               refs.set(i,
+                                               new 
RexInputRef(leftFields.size() + groupAndIndicatorCount
+                                                               + 
aggCallList.size(), newAggCall.getType()));
+                       }
+                       aggCallList.add(newAggCall);
+               }
+
+               final Map<Integer, Integer> map = new HashMap<>();
+               for (Integer key : aggregate.getGroupSet()) {
+                       map.put(key, map.size());
+               }
+               final ImmutableBitSet newGroupSet = 
aggregate.getGroupSet().permute(map);
+               assert newGroupSet
+                               
.equals(ImmutableBitSet.range(aggregate.getGroupSet().cardinality()));
+               ImmutableList<ImmutableBitSet> newGroupingSets = null;
+               if (aggregate.indicator) {
+                       newGroupingSets =
+                                       
ImmutableBitSet.ORDERING.immutableSortedCopy(
+                                                       
ImmutableBitSet.permute(aggregate.getGroupSets(), map));
+               }
+
+               relBuilder.push(
+                               aggregate.copy(aggregate.getTraitSet(), 
relBuilder.build(),
+                                               aggregate.indicator, 
newGroupSet, newGroupingSets, aggCallList));
+
+               // If there's no left child yet, no need to create the join
+               if (n == 0) {
+                       return;
+               }
+
+               // Create the join condition. It is of the form
+               //  'left.f0 = right.f0 and left.f1 = right.f1 and ...'
+               // where {f0, f1, ...} are the GROUP BY fields.
+               final List<RelDataTypeField> distinctFields =
+                               relBuilder.peek().getRowType().getFieldList();
+               final List<RexNode> conditions = Lists.newArrayList();
+               for (i = 0; i < groupAndIndicatorCount; ++i) {
+                       // null values form its own group
+                       // use "is not distinct from" so that the join condition
+                       // allows null values to match.
+                       conditions.add(
+                                       
rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_DISTINCT_FROM,
+                                                       RexInputRef.of(i, 
leftFields),
+                                                       new 
RexInputRef(leftFields.size() + i,
+                                                                       
distinctFields.get(i).getType())));
+               }
+
+               // Join in the new 'select distinct' relation.
+               relBuilder.join(JoinRelType.INNER, conditions);
+       }
+
+       private static void rewriteAggCalls(
+                       List<AggregateCall> newAggCalls,
+                       List<Integer> argList,
+                       Map<Integer, Integer> sourceOf) {
+               // Rewrite the agg calls. Each distinct agg becomes a 
non-distinct call
+               // to the corresponding field from the right; for example,
+               // "COUNT(DISTINCT e.sal)" becomes   "COUNT(distinct_e.sal)".
+               for (int i = 0; i < newAggCalls.size(); i++) {
+                       final AggregateCall aggCall = newAggCalls.get(i);
+
+                       // Ignore agg calls which are not distinct or have the 
wrong set
+                       // arguments. If we're rewriting aggregates whose args 
are {sal}, we will
+                       // rewrite COUNT(DISTINCT sal) and SUM(DISTINCT sal) 
but ignore
+                       // COUNT(DISTINCT gender) or SUM(sal).
+                       if (!aggCall.isDistinct()) {
+                               continue;
+                       }
+                       if (!aggCall.getArgList().equals(argList)) {
+                               continue;
+                       }
+
+                       // Re-map arguments.
+                       final int argCount = aggCall.getArgList().size();
+                       final List<Integer> newArgs = new ArrayList<>(argCount);
+                       for (int j = 0; j < argCount; j++) {
+                               final Integer arg = aggCall.getArgList().get(j);
+                               newArgs.add(sourceOf.get(arg));
+                       }
+                       final AggregateCall newAggCall =
+                                       
AggregateCall.create(aggCall.getAggregation(), false, newArgs, -1,
+                                                       aggCall.getType(), 
aggCall.getName());
+                       newAggCalls.set(i, newAggCall);
+               }
+       }
+
	/**
	 * Given an {@link org.apache.calcite.rel.logical.LogicalAggregate}
	 * and the ordinals of the arguments to a
	 * particular call to an aggregate function, creates a 'select distinct'
	 * relational expression which projects the group columns and those
	 * arguments but nothing else.
	 *
	 * <p>For example, given
	 *
	 * <blockquote>
	 * <pre>select f0, count(distinct f1), count(distinct f2)
	 * from t group by f0</pre>
	 * </blockquote>
	 *
	 * and the argument list
	 *
	 * <blockquote>{2}</blockquote>
	 *
	 * returns
	 *
	 * <blockquote>
	 * <pre>select distinct f0, f2 from t</pre>
	 * </blockquote>
	 *
	 * <p>The <code>sourceOf</code> map is populated with the source of each
	 * column; in this case sourceOf.get(0) = 0, and sourceOf.get(1) = 2.</p>
	 *
	 * @param relBuilder Relational expression builder
	 * @param aggregate Aggregate relational expression
	 * @param argList   Ordinals of columns to make distinct
	 * @param filterArg Ordinal of column to filter on, or -1
	 * @param sourceOf  Out parameter, is populated with a map of where each
	 *                                output field came from
	 * @return Aggregate relational expression which projects the required
	 * columns
	 */
	private RelBuilder createSelectDistinct(RelBuilder relBuilder,
											Aggregate aggregate, List<Integer> argList, int filterArg,
											Map<Integer, Integer> sourceOf) {
		relBuilder.push(aggregate.getInput());
		final List<Pair<RexNode, String>> projects = new ArrayList<>();
		final List<RelDataTypeField> childFields =
				relBuilder.peek().getRowType().getFieldList();
		// Project the GROUP BY columns first; record where each one lands so
		// callers can re-map references into the new relation.
		for (int i : aggregate.getGroupSet()) {
			sourceOf.put(i, projects.size());
			projects.add(RexInputRef.of2(i, childFields));
		}
		for (Integer arg : argList) {
			if (filterArg >= 0) {
				// Implement
				//   agg(DISTINCT arg) FILTER $f
				// by generating
				//   SELECT DISTINCT ... CASE WHEN $f THEN arg ELSE NULL END AS arg
				// and then applying
				//   agg(arg)
				// as usual.
				//
				// It works except for (rare) agg functions that need to see null
				// values.
				final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
				final RexInputRef filterRef = RexInputRef.of(filterArg, childFields);
				final Pair<RexNode, String> argRef = RexInputRef.of2(arg, childFields);
				RexNode condition =
						rexBuilder.makeCall(SqlStdOperatorTable.CASE, filterRef,
								argRef.left,
								rexBuilder.ensureType(argRef.left.getType(),
										rexBuilder.constantNull(), true));
				sourceOf.put(arg, projects.size());
				projects.add(Pair.of(condition, "i$" + argRef.right));
				continue;
			}
			// Skip arguments already projected above (e.g. an argument that is
			// also a GROUP BY column); sourceOf already points at that field.
			if (sourceOf.get(arg) != null) {
				continue;
			}
			sourceOf.put(arg, projects.size());
			projects.add(RexInputRef.of2(arg, childFields));
		}
		relBuilder.project(Pair.left(projects), Pair.right(projects));

		// Get the distinct values of the GROUP BY fields and the arguments
		// to the agg functions: an Aggregate that groups on every projected
		// column and computes no aggregate calls, i.e. SELECT DISTINCT.
		relBuilder.push(
				aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), false,
						ImmutableBitSet.range(projects.size()),
						null, ImmutableList.<AggregateCall>of()));
		return relBuilder;
	}
+}
+
+// End AggregateExpandDistinctAggregatesRule.java

http://git-wip-us.apache.org/repos/asf/flink/blob/36c9348f/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
index f9c8d8d..8f16d32 100644
--- 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
+++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
@@ -20,7 +20,7 @@ package org.apache.flink.table.plan.rules
 
 import org.apache.calcite.rel.rules._
 import org.apache.calcite.tools.{RuleSet, RuleSets}
-import org.apache.flink.table.calcite.rules.FlinkAggregateJoinTransposeRule
+import 
org.apache.flink.table.calcite.rules.{FlinkAggregateExpandDistinctAggregatesRule,
 FlinkAggregateJoinTransposeRule}
 import org.apache.flink.table.plan.rules.dataSet._
 import org.apache.flink.table.plan.rules.datastream._
 import org.apache.flink.table.plan.rules.datastream.{DataStreamCalcRule, 
DataStreamScanRule, DataStreamUnionRule}
@@ -102,6 +102,9 @@ object FlinkRuleSets {
     ProjectToCalcRule.INSTANCE,
     CalcMergeRule.INSTANCE,
 
+    // distinct aggregate rule for FLINK-3475
+    FlinkAggregateExpandDistinctAggregatesRule.JOIN,
+
     // translate to Flink DataSet nodes
     DataSetWindowAggregateRule.INSTANCE,
     DataSetAggregateRule.INSTANCE,

http://git-wip-us.apache.org/repos/asf/flink/blob/36c9348f/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala
 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala
index d1f932e..9c0acdd 100644
--- 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala
+++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala
@@ -44,9 +44,6 @@ class DataSetAggregateRule
 
     // check if we have distinct aggregates
     val distinctAggs = agg.getAggCallList.exists(_.isDistinct)
-    if (distinctAggs) {
-      throw TableException("DISTINCT aggregates are currently not supported.")
-    }
 
     !distinctAggs
   }

http://git-wip-us.apache.org/repos/asf/flink/blob/36c9348f/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala
 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala
index e8084fa..aa977b1 100644
--- 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala
+++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala
@@ -50,9 +50,6 @@ class DataSetAggregateWithNullValuesRule
 
     // check if we have distinct aggregates
     val distinctAggs = agg.getAggCallList.exists(_.isDistinct)
-    if (distinctAggs) {
-      throw TableException("DISTINCT aggregates are currently not supported.")
-    }
 
     !distinctAggs
   }

http://git-wip-us.apache.org/repos/asf/flink/blob/36c9348f/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsITCase.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsITCase.scala
 
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsITCase.scala
index d7e429c..a60cfaa 100644
--- 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsITCase.scala
+++ 
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsITCase.scala
@@ -213,7 +213,7 @@ class AggregationsITCase(
     TestBaseUtils.compareResultAsText(results.asJava, expected)
   }
 
-  @Test(expected = classOf[TableException])
+  @Test
   def testDistinctAggregate(): Unit = {
 
     val env = ExecutionEnvironment.getExecutionEnvironment
@@ -221,14 +221,17 @@ class AggregationsITCase(
 
     val sqlQuery = "SELECT sum(_1) as a, count(distinct _3) as b FROM MyTable"
 
-    val ds = CollectionDataSets.get3TupleDataSet(env)
-    tEnv.registerDataSet("MyTable", ds)
+    val ds = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv)
+    tEnv.registerTable("MyTable", ds)
 
-    // must fail. distinct aggregates are not supported
-    tEnv.sql(sqlQuery).toDataSet[Row]
+    val result = tEnv.sql(sqlQuery)
+
+    val expected = "231,21"
+    val results = result.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
   }
 
-  @Test(expected = classOf[TableException])
+  @Test
   def testGroupedDistinctAggregate(): Unit = {
 
     val env = ExecutionEnvironment.getExecutionEnvironment
@@ -236,11 +239,15 @@ class AggregationsITCase(
 
     val sqlQuery = "SELECT _2, avg(distinct _1) as a, count(_3) as b FROM 
MyTable GROUP BY _2"
 
-    val ds = CollectionDataSets.get3TupleDataSet(env)
-    tEnv.registerDataSet("MyTable", ds)
+    val ds = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv)
+    tEnv.registerTable("MyTable", ds)
 
-    // must fail. distinct aggregates are not supported
-    tEnv.sql(sqlQuery).toDataSet[Row]
+    val result = tEnv.sql(sqlQuery)
+
+    val expected =
+      "6,18,6\n5,13,5\n4,8,4\n3,5,3\n2,2,2\n1,1,1"
+    val results = result.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
   }
 
   @Test

http://git-wip-us.apache.org/repos/asf/flink/blob/36c9348f/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DistinctAggregateTest.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DistinctAggregateTest.scala
 
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DistinctAggregateTest.scala
new file mode 100644
index 0000000..38e4ea8
--- /dev/null
+++ 
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DistinctAggregateTest.scala
@@ -0,0 +1,476 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.api.scala.batch.sql
+
+import org.apache.flink.table.utils.TableTestBase
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.utils.TableTestUtil._
+import org.junit.Test
+
/**
  * Plan tests for batch SQL queries with DISTINCT aggregates.
  *
  * Each test registers a table MyTable(a: Int, b: Long, c: String), runs a
  * SQL query containing one or more DISTINCT aggregates, and verifies the
  * optimized DataSet plan via `util.verifySql`.
  */
class DistinctAggregateTest extends TableTestBase {

  @Test
  def testSingleDistinctAggregate(): Unit = {
    val util = batchTestUtil()
    util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)

    val sqlQuery = "SELECT COUNT(DISTINCT a) FROM MyTable"

    // COUNT(DISTINCT a) is expanded into an inner aggregate grouped on 'a'
    // (deduplication) followed by a plain COUNT. The DataSetValues/
    // DataSetUnion pair unions in a null tuple so the global aggregate also
    // produces a row for empty input.
    val expected = unaryNode(
      "DataSetAggregate",
      unaryNode(
        "DataSetUnion",
        unaryNode(
          "DataSetValues",
          unaryNode(
            "DataSetAggregate",
            unaryNode(
              "DataSetCalc",
              batchTableNode(0),
              term("select", "a")
            ),
            term("groupBy", "a"),
            term("select", "a")
          ),
          tuples(List(null)),
          term("values", "a")
        ),
        term("union", "a")
      ),
      term("select", "COUNT(a) AS EXPR$0")
    )

    util.verifySql(sqlQuery, expected)
  }

  @Test
  def testMultiDistinctAggregateOnSameColumn(): Unit = {
    val util = batchTestUtil()
    util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)

    val sqlQuery = "SELECT COUNT(DISTINCT a), SUM(DISTINCT a), MAX(DISTINCT a) FROM MyTable"

    // All three distinct aggregates are over the same column 'a', so a
    // single deduplicating inner aggregate on 'a' feeds all of them.
    val expected = unaryNode(
      "DataSetAggregate",
      unaryNode(
        "DataSetUnion",
        unaryNode(
          "DataSetValues",
          unaryNode(
            "DataSetAggregate",
            unaryNode(
              "DataSetCalc",
              batchTableNode(0),
              term("select", "a")
            ),
            term("groupBy", "a"),
            term("select", "a")
          ),
          tuples(List(null)),
          term("values", "a")
        ),
        term("union", "a")
      ),
      term("select", "COUNT(a) AS EXPR$0", "SUM(a) AS EXPR$1", "MAX(a) AS EXPR$2")
    )

    util.verifySql(sqlQuery, expected)
  }

  @Test
  def testSingleDistinctAggregateAndOneOrMultiNonDistinctAggregate(): Unit = {
    val util = batchTestUtil()
    util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)

    // case 0x00: DISTINCT on COUNT and Non-DISTINCT on others
    val sqlQuery0 = "SELECT COUNT(DISTINCT a), SUM(b) FROM MyTable"

    // The non-distinct SUM(b) is pre-aggregated in the same inner aggregate
    // that deduplicates 'a', then the partial sums are summed again on top.
    val expected0 = unaryNode(
      "DataSetAggregate",
      unaryNode(
        "DataSetUnion",
        unaryNode(
          "DataSetValues",
          unaryNode(
            "DataSetAggregate",
            unaryNode(
              "DataSetCalc",
              batchTableNode(0),
              term("select", "a", "b")
            ),
            term("groupBy", "a"),
            term("select", "a", "SUM(b) AS EXPR$1")
          ),
          tuples(List(null, null)),
          term("values", "a", "EXPR$1")
        ),
        term("union", "a", "EXPR$1")
      ),
      term("select", "COUNT(a) AS EXPR$0", "SUM(EXPR$1) AS EXPR$1")
    )

    util.verifySql(sqlQuery0, expected0)

    // case 0x01: Non-DISTINCT on COUNT and DISTINCT on others
    val sqlQuery1 = "SELECT COUNT(a), SUM(DISTINCT b) FROM MyTable"

    // Partial counts per distinct 'b' are combined with $SUM0 (sum that
    // yields 0 instead of null) to produce the overall COUNT(a).
    val expected1 = unaryNode(
      "DataSetAggregate",
      unaryNode(
        "DataSetUnion",
        unaryNode(
          "DataSetValues",
          unaryNode(
            "DataSetAggregate",
            unaryNode(
              "DataSetCalc",
              batchTableNode(0),
              term("select", "a", "b")
            ),
            term("groupBy", "b"),
            term("select", "b", "COUNT(a) AS EXPR$0")
          ),
          tuples(List(null, null)),
          term("values", "b", "EXPR$0")
        ),
        term("union", "b", "EXPR$0")
      ),
      term("select", "$SUM0(EXPR$0) AS EXPR$0", "SUM(b) AS EXPR$1")
    )

    util.verifySql(sqlQuery1, expected1)
  }

  @Test
  def testMultiDistinctAggregateOnDifferentColumn(): Unit = {
    val util = batchTestUtil()
    util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)

    val sqlQuery = "SELECT COUNT(DISTINCT a), SUM(DISTINCT b) FROM MyTable"

    // Distinct aggregates on different columns are computed independently
    // and combined with a cartesian single-row join (where "true").
    val expected = binaryNode(
      "DataSetSingleRowJoin",
      unaryNode(
        "DataSetAggregate",
        unaryNode(
          "DataSetUnion",
          unaryNode(
            "DataSetValues",
            unaryNode(
              "DataSetAggregate",
              unaryNode(
                "DataSetCalc",
                batchTableNode(0),
                term("select", "a")
              ),
              term("groupBy", "a"),
              term("select", "a")
            ),
            tuples(List(null)),
            term("values", "a")
          ),
          term("union", "a")
        ),
        term("select", "COUNT(a) AS EXPR$0")
      ),
      unaryNode(
        "DataSetAggregate",
        unaryNode(
          "DataSetUnion",
          unaryNode(
            "DataSetValues",
            unaryNode(
              "DataSetAggregate",
              unaryNode(
                "DataSetCalc",
                batchTableNode(0),
                term("select", "b")
              ),
              term("groupBy", "b"),
              term("select", "b")
            ),
            tuples(List(null)),
            term("values", "b")
          ),
          term("union", "b")
        ),
        term("select", "SUM(b) AS EXPR$1")
      ),
      term("where", "true"),
      term("join", "EXPR$0", "EXPR$1"),
      term("joinType", "NestedLoopJoin")
    )

    util.verifySql(sqlQuery, expected)
  }

  @Test
  def testMultiDistinctAndNonDistinctAggregateOnDifferentColumn(): Unit = {
    val util = batchTestUtil()
    util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)

    val sqlQuery = "SELECT COUNT(DISTINCT a), SUM(DISTINCT b), COUNT(c) FROM MyTable"

    // Three independent branches (COUNT(c), COUNT(DISTINCT a),
    // SUM(DISTINCT b)) combined by two nested single-row joins; the outer
    // DataSetCalc restores the original column order of the SELECT list.
    val expected = unaryNode(
      "DataSetCalc",
      binaryNode(
        "DataSetSingleRowJoin",
        binaryNode(
          "DataSetSingleRowJoin",
          unaryNode(
            "DataSetAggregate",
            unaryNode(
              "DataSetUnion",
              unaryNode(
                "DataSetValues",
                batchTableNode(0),
                tuples(List(null, null, null)),
                term("values", "a, b, c")
              ),
              term("union", "a, b, c")
            ),
            term("select", "COUNT(c) AS EXPR$2")
          ),
          unaryNode(
            "DataSetAggregate",
            unaryNode(
              "DataSetUnion",
              unaryNode(
                "DataSetValues",
                unaryNode(
                  "DataSetAggregate",
                  unaryNode(
                    "DataSetCalc",
                    batchTableNode(0),
                    term("select", "a")
                  ),
                  term("groupBy", "a"),
                  term("select", "a")
                ),
                tuples(List(null)),
                term("values", "a")
              ),
              term("union", "a")
            ),
            term("select", "COUNT(a) AS EXPR$0")
          ),
          term("where", "true"),
          term("join", "EXPR$2, EXPR$0"),
          term("joinType", "NestedLoopJoin")
        ),
        unaryNode(
          "DataSetAggregate",
          unaryNode(
            "DataSetUnion",
            unaryNode(
              "DataSetValues",
              unaryNode(
                "DataSetAggregate",
                unaryNode(
                  "DataSetCalc",
                  batchTableNode(0),
                  term("select", "b")
                ),
                term("groupBy", "b"),
                term("select", "b")
              ),
              tuples(List(null)),
              term("values", "b")
            ),
            term("union", "b")
          ),
          term("select", "SUM(b) AS EXPR$1")
        ),
        term("where", "true"),
        term("join", "EXPR$2", "EXPR$0, EXPR$1"),
        term("joinType", "NestedLoopJoin")
      ),
      term("select", "EXPR$0, EXPR$1, EXPR$2")
    )

    util.verifySql(sqlQuery, expected)
  }

  @Test
  def testSingleDistinctAggregateWithGrouping(): Unit = {
    val util = batchTestUtil()
    util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)

    val sqlQuery = "SELECT a, COUNT(a), SUM(DISTINCT b) FROM MyTable GROUP BY a"

    // With GROUP BY a, SUM(DISTINCT b) is computed by first grouping on
    // (a, b) — deduplicating b per group while partially counting — then
    // grouping on 'a' and summing both the partial counts and the
    // deduplicated b values.
    val expected = unaryNode(
      "DataSetAggregate",
      unaryNode(
        "DataSetAggregate",
        unaryNode(
          "DataSetCalc",
          batchTableNode(0),
          term("select", "a", "b")
        ),
        term("groupBy", "a", "b"),
        term("select", "a", "b", "COUNT(a) AS EXPR$1")
      ),
      term("groupBy", "a"),
      term("select", "a", "SUM(EXPR$1) AS EXPR$1", "SUM(b) AS EXPR$2")
    )

    util.verifySql(sqlQuery, expected)
  }

  @Test
  def testSingleDistinctAggregateWithGroupingAndCountStar(): Unit = {
    val util = batchTestUtil()
    util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)

    val sqlQuery = "SELECT a, COUNT(*), SUM(DISTINCT b) FROM MyTable GROUP BY a"

    // Same two-phase shape as testSingleDistinctAggregateWithGrouping, but
    // with COUNT(*) as the partially computed non-distinct aggregate.
    val expected = unaryNode(
      "DataSetAggregate",
      unaryNode(
        "DataSetAggregate",
        unaryNode(
          "DataSetCalc",
          batchTableNode(0),
          term("select", "a", "b")
        ),
        term("groupBy", "a", "b"),
        term("select", "a", "b", "COUNT(*) AS EXPR$1")
      ),
      term("groupBy", "a"),
      term("select", "a", "SUM(EXPR$1) AS EXPR$1", "SUM(b) AS EXPR$2")
    )

    util.verifySql(sqlQuery, expected)
  }

  @Test
  def testTwoDistinctAggregateWithGroupingAndCountStar(): Unit = {
    val util = batchTestUtil()
    util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)

    val sqlQuery = "SELECT a, COUNT(*), SUM(DISTINCT b), COUNT(DISTINCT b) FROM MyTable GROUP BY a"

    // Both distinct aggregates are over 'b', so a single deduplicated
    // branch (group by (a, b), then by a) serves both. It is joined back to
    // the COUNT(*) branch on the group key using the null-safe
    // IS NOT DISTINCT FROM predicate so null group keys still match.
    val expected = unaryNode(
      "DataSetCalc",
      binaryNode(
        "DataSetJoin",
        unaryNode(
          "DataSetAggregate",
          unaryNode(
            "DataSetCalc",
            batchTableNode(0),
            term("select", "a", "b")
          ),
          term("groupBy", "a"),
          term("select", "a", "COUNT(*) AS EXPR$1")
        ),
        unaryNode(
          "DataSetAggregate",
          unaryNode(
            "DataSetAggregate",
            unaryNode(
              "DataSetCalc",
              batchTableNode(0),
              term("select", "a", "b")
            ),
            term("groupBy", "a, b"),
            term("select", "a, b")
          ),
          term("groupBy", "a"),
          term("select", "a, SUM(b) AS EXPR$2, COUNT(b) AS EXPR$3")
        ),
        term("where", "IS NOT DISTINCT FROM(a, a0)"),
        term("join", "a, EXPR$1, a0, EXPR$2, EXPR$3"),
        term("joinType", "InnerJoin")
      ),
      term("select", "a, EXPR$1, EXPR$2, EXPR$3")
    )

    util.verifySql(sqlQuery, expected)
  }

  @Test
  def testTwoDifferentDistinctAggregateWithGroupingAndCountStar(): Unit = {
    val util = batchTestUtil()
    util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)

    val sqlQuery = "SELECT a, COUNT(*), SUM(DISTINCT b), COUNT(DISTINCT c) FROM MyTable GROUP BY a"

    // Distinct aggregates on different columns ('b' and 'c') each get their
    // own deduplicated branch; the branches are stitched back to the
    // COUNT(*) branch with two null-safe inner joins on the group key.
    val expected = unaryNode(
      "DataSetCalc",
      binaryNode(
        "DataSetJoin",
        unaryNode(
          "DataSetCalc",
          binaryNode(
            "DataSetJoin",
            unaryNode(
              "DataSetAggregate",
              batchTableNode(0),
              term("groupBy", "a"),
              term("select", "a, COUNT(*) AS EXPR$1")
            ),
            unaryNode(
              "DataSetAggregate",
              unaryNode(
                "DataSetAggregate",
                unaryNode(
                  "DataSetCalc",
                  batchTableNode(0),
                  term("select", "a", "b")
                ),
                term("groupBy", "a, b"),
                term("select", "a, b")
              ),
              term("groupBy", "a"),
              term("select", "a, SUM(b) AS EXPR$2")
            ),
            term("where", "IS NOT DISTINCT FROM(a, a0)"),
            term("join", "a, EXPR$1, a0, EXPR$2"),
            term("joinType", "InnerJoin")
          ),
          term("select", "a, EXPR$1, EXPR$2")
        ),
        unaryNode(
          "DataSetAggregate",
          unaryNode(
            "DataSetAggregate",
            unaryNode(
              "DataSetCalc",
              batchTableNode(0),
              term("select", "a", "c")
            ),
            term("groupBy", "a, c"),
            term("select", "a, c")
          ),
          term("groupBy", "a"),
          term("select", "a, COUNT(c) AS EXPR$3")
        ),
        term("where", "IS NOT DISTINCT FROM(a, a0)"),
        term("join", "a, EXPR$1, EXPR$2, a0, EXPR$3"),
        term("joinType", "InnerJoin")
      ),
      term("select", "a, EXPR$1, EXPR$2, EXPR$3")
    )

    util.verifySql(sqlQuery, expected)
  }

}

http://git-wip-us.apache.org/repos/asf/flink/blob/36c9348f/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/QueryDecorrelationTest.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/QueryDecorrelationTest.scala
 
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/QueryDecorrelationTest.scala
index abf71e2..516fcd2 100644
--- 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/QueryDecorrelationTest.scala
+++ 
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/QueryDecorrelationTest.scala
@@ -85,7 +85,7 @@ class QueryDecorrelationTest extends TableTestBase {
           term("join", "empno", "salary", "empno0"),
           term("joinType", "InnerJoin")
         ),
-        term("select", "salary", "empno0")
+        term("select", "empno0", "salary")
       ),
       term("groupBy", "empno0"),
       term("select", "empno0", "AVG(salary) AS EXPR$0")

Reply via email to