GitHub user yhuai commented on a diff in the pull request:

    https://github.com/apache/spark/pull/11583#discussion_r57539427
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 ---
    @@ -309,38 +309,64 @@ class Analyzer(
     
       object ResolvePivot extends Rule[LogicalPlan] {
         def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    -      case p: Pivot if !p.childrenResolved | 
!p.aggregates.forall(_.resolved) => p
    +      case p: Pivot if !p.childrenResolved | 
!p.aggregates.forall(_.resolved)
    +        | !p.groupByExprs.forall(_.resolved) | !p.pivotColumn.resolved => p
           case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, 
child) =>
             val singleAgg = aggregates.size == 1
    -        val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { 
value =>
    -          def ifExpr(expr: Expression) = {
    -            If(EqualTo(pivotColumn, value), expr, Literal(null))
    +        def outputName(value: Literal, aggregate: Expression): String = {
    +          if (singleAgg) value.toString else value + "_" + aggregate.sql
    +        }
    +        if (pivotValues.length >= 10
    +          && aggregates.forall(a => 
PivotFirst.supportsDataType(a.dataType))) {
    +          // Since evaluating |pivotValues| if statements for each input 
row can get slow this is an
    +          // alternate plan that instead uses two steps of aggregation.
    +          val namedAggExps: Seq[NamedExpression] = aggregates.map(a => 
Alias(a, a.sql)())
    +          val namedPivotCol = pivotColumn match {
    +            case n: NamedExpression => n
    +            case _ => Alias(pivotColumn, "__pivot_col")()
    +          }
    +          val bigGroup = groupByExprs :+ namedPivotCol
    +          val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, 
child)
    +          val castPivotValues = pivotValues.map(Cast(_, 
pivotColumn.dataType).eval(EmptyRow))
    +          val pivotAggs = namedAggExps.map { a =>
    +            Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, 
castPivotValues)
    +              .toAggregateExpression()
    +            , "__pivot_" + a.sql)()
               }
    -          aggregates.map { aggregate =>
    -            val filteredAggregate = aggregate.transformDown {
    -              // Assumption is the aggregate function ignores nulls. This 
is true for all current
    -              // AggregateFunction's with the exception of First and Last 
in their default mode
    -              // (which we handle) and possibly some Hive UDAF's.
    -              case First(expr, _) =>
    -                First(ifExpr(expr), Literal(true))
    -              case Last(expr, _) =>
    -                Last(ifExpr(expr), Literal(true))
    -              case a: AggregateFunction =>
    -                a.withNewChildren(a.children.map(ifExpr))
    +          val secondAgg = Aggregate(groupByExprs, groupByExprs ++ 
pivotAggs, firstAgg)
    +          val pivotAggAttribute = pivotAggs.map(_.toAttribute)
    +          val pivotOutputs = pivotValues.zipWithIndex.flatMap { case 
(value, i) =>
    +            aggregates.zip(pivotAggAttribute).map { case (aggregate, 
pivotAtt) =>
    +              Alias(ExtractValue(pivotAtt, Literal(i), resolver), 
outputName(value, aggregate))()
                 }
    -            if (filteredAggregate.fastEquals(aggregate)) {
    -              throw new AnalysisException(
    -                s"Aggregate expression required for pivot, found 
'$aggregate'")
    +          }
    +          Project(groupByExprs ++ pivotOutputs, secondAgg)
    +        } else {
    +          val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap 
{ value =>
    +            def ifExpr(expr: Expression) = {
    +              If(EqualTo(pivotColumn, value), expr, Literal(null))
    +            }
    +            aggregates.map { aggregate =>
    +              val filteredAggregate = aggregate.transformDown {
    +                // Assumption is the aggregate function ignores nulls. 
This is true for all current
    +                // AggregateFunction's with the exception of First and 
Last in their default mode
    +                // (which we handle) and possibly some Hive UDAF's.
    +                case First(expr, _) =>
    +                  First(ifExpr(expr), Literal(true))
    +                case Last(expr, _) =>
    +                  Last(ifExpr(expr), Literal(true))
    +                case a: AggregateFunction =>
    +                  a.withNewChildren(a.children.map(ifExpr))
    +              }
    +              if (filteredAggregate.fastEquals(aggregate)) {
    +                throw new AnalysisException(
    +                  s"Aggregate expression required for pivot, found 
'$aggregate'")
    +              }
    +              Alias(filteredAggregate, outputName(value, aggregate))()
                 }
    -            val name = if (singleAgg) value.toString else value + "_" + 
aggregate.sql
    -            Alias(filteredAggregate, name)()
               }
    +          Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child)
             }
    -        val newGroupByExprs = groupByExprs.map {
    -          case UnresolvedAlias(e, _) => e
    -          case e => e
    -        }
    --- End diff --
    
    Is this map (`newGroupByExprs`, which unwraps `UnresolvedAlias`) no longer needed?


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to