GitHub user liancheng commented on a diff in pull request #11283:
https://github.com/apache/spark/pull/11283#discussion_r55016431
--- Diff:
sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala ---
@@ -233,6 +236,82 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext:
SQLContext) extends Loggi
)
}
+ private def sameOutput(output1: Seq[Attribute], output2:
Seq[Attribute]): Boolean =
+ output1.size == output2.size &&
+ output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2))
+
+ private def isGroupingSet(a: Aggregate, e: Expand, p: Project): Boolean
= {
+ assert(a.child == e && e.child == p)
+ a.groupingExpressions.forall(_.isInstanceOf[Attribute]) &&
+ sameOutput(e.output, p.child.output ++
a.groupingExpressions.map(_.asInstanceOf[Attribute]))
+ }
+
+ private def groupingSetToSQL(
+ agg: Aggregate,
+ expand: Expand,
+ project: Project): String = {
+ assert(agg.groupingExpressions.length > 1)
+
+ // The last column of Expand is always grouping ID
+ val gid = expand.output.last
+
+ val numOriginalOutput = project.child.output.length
+ // Assumption: Aggregate's groupingExpressions is composed of
+ // 1) the attributes of aliased group by expressions
+ // 2) gid, which is always the last one
+ val groupByAttributes =
agg.groupingExpressions.dropRight(1).map(_.asInstanceOf[Attribute])
+ // Assumption: Project's projectList is composed of
+ // 1) the original output (Project's child.output),
+ // 2) the aliased group by expressions.
+ val groupByExprs =
project.projectList.drop(numOriginalOutput).map(_.asInstanceOf[Alias].child)
+ val groupingSQL = groupByExprs.map(_.sql).mkString(", ")
+
+ // a map from group by attributes to the original group by expressions.
+ val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs))
+
+ val groupingSet = expand.projections.map { project =>
+ // Assumption: expand.projections is composed of
+ // 1) the original output (Project's child.output),
+ // 2) group by attributes(or null literal)
+ // 3) gid, which is always the last one in each project in Expand
+ project.drop(numOriginalOutput).dropRight(1).collect {
+ case attr: Attribute if groupByAttrMap.contains(attr) =>
groupByAttrMap(attr)
+ }
+ }
+ val groupingSetSQL =
+ "GROUPING SETS(" +
+ groupingSet.map(e => s"(${e.map(_.sql).mkString(",
")})").mkString(", ") + ")"
+
+ val aggExprs = agg.aggregateExpressions.map { case expr =>
+ expr.transformDown {
+ // grouping_id() is converted to VirtualColumn.groupingIdName by
Analyzer. Revert it back.
+ case ar: AttributeReference if ar == gid => GroupingID(Nil)
+ case ar: AttributeReference if groupByAttrMap.contains(ar) =>
groupByAttrMap(ar)
+ case a @ Cast(BitwiseAnd(
+ ShiftRight(ar: AttributeReference, _ @ Literal(value: Any,
IntegerType)),
+ Literal(1, IntegerType)), ByteType) if ar == gid =>
+ // for converting an expression to its original SQL format
grouping(col)
+ val idx = groupByExprs.length - 1 - value.asInstanceOf[Int]
+ val groupingCol = groupByExprs.lift(idx)
+ if (groupingCol.isDefined) {
+ Grouping(groupingCol.get)
+ } else {
+ throw new UnsupportedOperationException(s"unsupported operator
$a")
+ }
+ }
+ }
+
+ build(
+ "SELECT",
+ aggExprs.map(_.sql).mkString(", "),
+ if (agg.child == OneRowRelation) "" else "FROM",
+ toSQL(project.child),
--- End diff --
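Let's add some test cases where `project.child` is itself a more
complicated plan (for example, a join or a subquery rather than a plain
table scan), so that `toSQL(project.child)` is exercised on non-trivial inputs.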
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]