Github user yhuai commented on a diff in the pull request:
https://github.com/apache/spark/pull/11583#discussion_r61764184
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
---
@@ -363,43 +363,68 @@ class Analyzer(
object ResolvePivot extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case p: Pivot if !p.childrenResolved |
!p.aggregates.forall(_.resolved) => p
+ case p: Pivot if !p.childrenResolved |
!p.aggregates.forall(_.resolved)
+ | !p.groupByExprs.forall(_.resolved) | !p.pivotColumn.resolved => p
case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates,
child) =>
val singleAgg = aggregates.size == 1
- val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap {
value =>
- def ifExpr(expr: Expression) = {
- If(EqualTo(pivotColumn, value), expr, Literal(null))
+ def outputName(value: Literal, aggregate: Expression): String = {
+ if (singleAgg) value.toString else value + "_" + aggregate.sql
+ }
+ if (aggregates.forall(a =>
PivotFirst.supportsDataType(a.dataType))) {
+ // Since evaluating |pivotValues| if statements for each input
row can get slow this is an
+ // alternate plan that instead uses two steps of aggregation.
+ val namedAggExps: Seq[NamedExpression] = aggregates.map(a =>
Alias(a, a.sql)())
+ val namedPivotCol = pivotColumn match {
+ case n: NamedExpression => n
+ case _ => Alias(pivotColumn, "__pivot_col")()
+ }
+ val bigGroup = groupByExprs :+ namedPivotCol
+ val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps,
child)
+ val castPivotValues = pivotValues.map(Cast(_,
pivotColumn.dataType).eval(EmptyRow))
+ val pivotAggs = namedAggExps.map { a =>
+ Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute,
castPivotValues)
+ .toAggregateExpression()
+ , "__pivot_" + a.sql)()
+ }
+ val secondAgg = Aggregate(groupByExprs, groupByExprs ++
pivotAggs, firstAgg)
+ val pivotAggAttribute = pivotAggs.map(_.toAttribute)
+ val pivotOutputs = pivotValues.zipWithIndex.flatMap { case
(value, i) =>
+ aggregates.zip(pivotAggAttribute).map { case (aggregate,
pivotAtt) =>
+ Alias(ExtractValue(pivotAtt, Literal(i), resolver),
outputName(value, aggregate))()
+ }
}
- aggregates.map { aggregate =>
- val filteredAggregate = aggregate.transformDown {
- // Assumption is the aggregate function ignores nulls. This
is true for all current
- // AggregateFunction's with the exception of First and Last
in their default mode
- // (which we handle) and possibly some Hive UDAF's.
- case First(expr, _) =>
- First(ifExpr(expr), Literal(true))
- case Last(expr, _) =>
- Last(ifExpr(expr), Literal(true))
- case a: AggregateFunction =>
- a.withNewChildren(a.children.map(ifExpr))
- }.transform {
- // We are duplicating aggregates that are now computing a
different value for each
- // pivot value.
- // TODO: Don't construct the physical container until after
analysis.
- case ae: AggregateExpression => ae.copy(resultId =
NamedExpression.newExprId)
+ Project(groupByExprs ++ pivotOutputs, secondAgg)
+ } else {
--- End diff --
Since we will decide which branch to use based on the data types, do we
still have enough test coverage for this else branch?
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes to enable it, or if the feature is enabled but not working,
please contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]