GitHub user fhueske commented on a diff in the pull request:
https://github.com/apache/flink/pull/4873#discussion_r147142386
--- Diff:
flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/logical/DecomposeGroupingSetRule.scala
---
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.rules.logical
+
+import org.apache.calcite.plan.RelOptRule._
+import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelOptRuleCall}
+import org.apache.calcite.rel.core.AggregateCall
+import org.apache.calcite.rel.logical._
+import org.apache.calcite.rex.RexNode
+import org.apache.calcite.sql.SqlKind
+import org.apache.calcite.tools.RelBuilder
+import org.apache.calcite.util.ImmutableBitSet
+
+import scala.collection.JavaConversions._
+
/**
 * Planner rule that decomposes a [[LogicalAggregate]] containing one of the grouping
 * functions GROUP_ID, GROUPING or GROUPING_ID into one aggregate per grouping set,
 * combined with UNION ALL.
 *
 * Each per-set plan is produced by `DecomposeGroupingSetRule.decompose`. The plans are
 * chained into nested binary [[LogicalUnion]] nodes with `all = true`, so duplicate rows
 * are kept.
 */
class DecomposeGroupingSetRule
  extends RelOptRule(
    operand(classOf[LogicalAggregate], any),
    "DecomposeGroupingSetRule") {

  /**
   * Matches only aggregates that have a non-empty list of grouping sets and at least one
   * grouping-function call (GROUP_ID, GROUPING, GROUPING_ID).
   */
  override def matches(call: RelOptRuleCall): Boolean = {
    val aggregate = call.rel[LogicalAggregate](0)
    // NOTE(review): an Aggregate presumably always carries at least one grouping set, which
    // would make this check always true. Aggregates with several grouping sets but no
    // grouping-function call are then not matched by this rule; confirm that is intended.
    val hasGroupSets = !aggregate.getGroupSets.isEmpty
    hasGroupSets &&
      DecomposeGroupingSetRule.getGroupIdExprIndexes(aggregate.getAggCallList).nonEmpty
  }

  /**
   * Replaces the matched aggregate with the UNION ALL of its per-grouping-set plans.
   */
  override def onMatch(call: RelOptRuleCall): Unit = {
    val aggregate = call.rel[LogicalAggregate](0)
    val groupIdCallIndexes =
      DecomposeGroupingSetRule.getGroupIdExprIndexes(aggregate.getAggCallList).toSet

    // One sub-plan per grouping set, in the order the sets appear on the aggregate.
    val perSetPlans = aggregate.groupSets.map { groupSet =>
      DecomposeGroupingSetRule.decompose(call.builder(), aggregate, groupIdCallIndexes, groupSet)
    }

    // NOTE(review): each union is kept binary (two inputs), presumably because the downstream
    // union translation expects exactly two inputs. Confirm before flattening into one n-ary union.
    val combined = perSetPlans.reduce { (left, right) =>
      new LogicalUnion(aggregate.getCluster, aggregate.getTraitSet, Seq(left, right), true)
    }
    call.transformTo(combined)
  }
}
+
+object DecomposeGroupingSetRule {
+ val INSTANCE = new DecomposeGroupingSetRule
+
+ private def getGroupIdExprIndexes(aggCalls: Seq[AggregateCall]) = {
+ aggCalls.zipWithIndex.filter { case (call, _) =>
+ call.getAggregation.getKind match {
+ case SqlKind.GROUP_ID | SqlKind.GROUPING | SqlKind.GROUPING_ID =>
+ true
+ case _ =>
+ false
+ }
+ }.map { case (_, idx) => idx}
+ }
+
+ private def decompose(
+ relBuilder: RelBuilder,
+ agg: LogicalAggregate,
+ groupExprIndexes : Set[Int],
+ groupSet: ImmutableBitSet
+ ) = {
+ val aggsWithIndexes = agg.getAggCallList.zipWithIndex
+ val subAgg = new LogicalAggregate(
+ agg.getCluster,
+ agg.getTraitSet,
+ agg.getInput,
+ false,
+ groupSet,
+ Seq(),
+ aggsWithIndexes
+ .filter { case (_, idx) => !groupExprIndexes.contains(idx) }
+ .map { case (call, _) => call}
+ )
+
+ val rexBuilder = relBuilder.getRexBuilder
+ relBuilder.push(subAgg)
+
+ val groupingFields = new Array[RexNode](agg.getGroupCount)
+ val groupingFieldsName = Seq.range(0, agg.getGroupCount).map(
+ x => agg.getRowType.getFieldNames.get(x)
+ )
+ Seq.range(0, agg.getGroupCount).foreach(x =>
+ groupingFields(x) = rexBuilder.makeNullLiteral(
+ agg.getRowType.getFieldList.get(x).getType)
+ )
+
+ groupSet.toList.zipWithIndex.foreach { case (group, idx) =>
+ groupingFields(group) = rexBuilder.makeInputRef(relBuilder.peek(),
idx)
+ }
+
+ val aggFields = aggsWithIndexes.map { case (call, idx) =>
+ if (groupExprIndexes.contains(idx)) {
+ lowerGroupExpr(agg.getCluster, call, groupSet)
+ } else {
+ rexBuilder.makeInputRef(subAgg, idx + subAgg.getGroupCount)
--- End diff --
This will break if a group-id expression appears before a regular aggregation function.
For example, in `SELECT SUM(x), GROUP_ID(), AVG(y)`, `AVG(y)` has index 2 (it is the third
call). However, it is at position `groupCnt + 1` in the result of the sub-aggregation, because
`GROUP_ID()` is removed from the sub-aggregation's calls and every later call shifts left by one.
---