vladimirg-db commented on code in PR #53474:
URL: https://github.com/apache/spark/pull/53474#discussion_r2619861021
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala:
##########
@@ -897,100 +897,17 @@ class Analyzer(
}
// Check all aggregate expressions.
aggregates.foreach(checkValidAggregateExpression)
- // Check all pivot values are literal and match pivot column data type.
- val evalPivotValues = pivotValues.map { value =>
- val foldable = trimAliases(value).foldable
- if (!foldable) {
- throw QueryCompilationErrors.nonLiteralPivotValError(value)
- }
- if (!Cast.canCast(value.dataType, pivotColumn.dataType)) {
- throw QueryCompilationErrors.pivotValDataTypeMismatchError(value,
pivotColumn)
- }
- Cast(value, pivotColumn.dataType,
Some(conf.sessionLocalTimeZone)).eval(EmptyRow)
- }
- // Group-by expressions coming from SQL are implicit and need to be
deduced.
- val groupByExprs = groupByExprsOpt.getOrElse {
- val pivotColAndAggRefs = pivotColumn.references ++
AttributeSet(aggregates)
- child.output.filterNot(pivotColAndAggRefs.contains)
- }
- val singleAgg = aggregates.size == 1
- def outputName(value: Expression, aggregate: Expression): String = {
- val stringValue = value match {
- case n: NamedExpression => n.name
- case _ =>
- val utf8Value =
- Cast(value, StringType,
Some(conf.sessionLocalTimeZone)).eval(EmptyRow)
- Option(utf8Value).map(_.toString).getOrElse("null")
- }
- if (singleAgg) {
- stringValue
- } else {
- val suffix = aggregate match {
- case n: NamedExpression => n.name
- case _ => toPrettySQL(aggregate)
- }
- stringValue + "_" + suffix
- }
- }
- if (aggregates.forall(a => PivotFirst.supportsDataType(a.dataType))) {
- // Since evaluating |pivotValues| if statements for each input row
can get slow this is an
- // alternate plan that instead uses two steps of aggregation.
- val namedAggExps: Seq[NamedExpression] = aggregates.map(a =>
Alias(a, a.sql)())
- val namedPivotCol = pivotColumn match {
- case n: NamedExpression => n
- case _ => Alias(pivotColumn, "__pivot_col")()
- }
- val bigGroup = groupByExprs :+ namedPivotCol
- val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child)
- val pivotAggs = namedAggExps.map { a =>
- Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute,
evalPivotValues)
- .toAggregateExpression()
- , "__pivot_" + a.sql)()
- }
- val groupByExprsAttr = groupByExprs.map(_.toAttribute)
- val secondAgg = Aggregate(groupByExprsAttr, groupByExprsAttr ++
pivotAggs, firstAgg)
- val pivotAggAttribute = pivotAggs.map(_.toAttribute)
- val pivotOutputs = pivotValues.zipWithIndex.flatMap { case (value,
i) =>
- aggregates.zip(pivotAggAttribute).map { case (aggregate, pivotAtt)
=>
- Alias(ExtractValue(pivotAtt, Literal(i), resolver),
outputName(value, aggregate))()
- }
- }
- Project(groupByExprsAttr ++ pivotOutputs, secondAgg)
- } else {
- val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap {
value =>
- def ifExpr(e: Expression) = {
- If(
- EqualNullSafe(
- pivotColumn,
- Cast(value, pivotColumn.dataType,
Some(conf.sessionLocalTimeZone))),
- e, Literal(null))
- }
- aggregates.map { aggregate =>
- val filteredAggregate = aggregate.transformDown {
- // Assumption is the aggregate function ignores nulls. This is
true for all current
- // AggregateFunction's with the exception of First and Last in
their default mode
- // (which we handle) and possibly some Hive UDAF's.
- case First(expr, _) =>
- First(ifExpr(expr), true)
- case Last(expr, _) =>
- Last(ifExpr(expr), true)
- case a: ApproximatePercentile =>
- // ApproximatePercentile takes two literals for accuracy and
percentage which
- // should not be wrapped by if-else.
- a.withNewChildren(ifExpr(a.first) :: a.second :: a.third ::
Nil)
- case a: AggregateFunction =>
- a.withNewChildren(a.children.map(ifExpr))
- }.transform {
- // We are duplicating aggregates that are now computing a
different value for each
- // pivot value.
- // TODO: Don't construct the physical container until after
analysis.
- case ae: AggregateExpression => ae.copy(resultId =
NamedExpression.newExprId)
- }
- Alias(filteredAggregate, outputName(value, aggregate))()
- }
- }
- Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child)
- }
+ PivotTransformer(
+ child = child,
+ pivotValues = pivotValues,
+ pivotColumn = pivotColumn,
+ groupByExpressionsOpt = groupByExprsOpt,
+ aggregates = aggregates,
+ childOutput = child.output,
+ newAlias = (child, name) =>
+ Alias(child, name.getOrElse(toPrettySQL(child)))(),
Review Comment:
`newAlias` is always called with `Some(...)`, so the
`.getOrElse(toPrettySQL(child))` fallback is never needed.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]