GitHub user hvanhovell commented on a diff in the pull request:

    https://github.com/apache/spark/pull/9556#discussion_r44343490
  
    --- Diff: 
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala ---
    @@ -146,148 +146,105 @@ private[sql] abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
         }
       }
     
    -  object HashAggregation extends Strategy {
    -    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    -      // Aggregations that can be performed in two phases, before and 
after the shuffle.
    -      case PartialAggregation(
    -          namedGroupingAttributes,
    -          rewrittenAggregateExpressions,
    -          groupingExpressions,
    -          partialComputation,
    -          child) if !canBeConvertedToNewAggregation(plan) =>
    -        execution.Aggregate(
    -          partial = false,
    -          namedGroupingAttributes,
    -          rewrittenAggregateExpressions,
    -          execution.Aggregate(
    -            partial = true,
    -            groupingExpressions,
    -            partialComputation,
    -            planLater(child))) :: Nil
    -
    -      case _ => Nil
    -    }
    -
    -    def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = plan 
match {
    -      case a: logical.Aggregate =>
    -        if (sqlContext.conf.useSqlAggregate2 && 
sqlContext.conf.codegenEnabled) {
    -          a.newAggregation.isDefined
    -        } else {
    -          Utils.checkInvalidAggregateFunction2(a)
    -          false
    -        }
    -      case _ => false
    -    }
    -
    -    def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression1] =
    -      exprs.flatMap(_.collect { case a: AggregateExpression1 => a })
    -  }
    -
       /**
        * Used to plan the aggregate operator for expressions based on the 
AggregateFunction2 interface.
        */
       object Aggregation extends Strategy {
         def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    -      case p: logical.Aggregate if sqlContext.conf.useSqlAggregate2 &&
    -          sqlContext.conf.codegenEnabled =>
    -        val converted = p.newAggregation
    -        converted match {
    -          case None => Nil // Cannot convert to new aggregation code path.
    -          case Some(logical.Aggregate(groupingExpressions, 
resultExpressions, child)) =>
    -            // A single aggregate expression might appear multiple times 
in resultExpressions.
    -            // In order to avoid evaluating an individual aggregate 
function multiple times, we'll
    -            // build a set of the distinct aggregate expressions and build 
a function which can
    -            // be used to re-write expressions so that they reference the 
single copy of the
    -            // aggregate function which actually gets computed.
    -            val aggregateExpressions = resultExpressions.flatMap { expr =>
    -              expr.collect {
    -                case agg: AggregateExpression2 => agg
    -              }
    -            }.distinct
    -            // For those distinct aggregate expressions, we create a map 
from the
    -            // aggregate function to the corresponding attribute of the 
function.
    -            val aggregateFunctionToAttribute = aggregateExpressions.map { 
agg =>
    -              val aggregateFunction = agg.aggregateFunction
    -              val attribute = Alias(aggregateFunction, 
aggregateFunction.toString)().toAttribute
    -              (aggregateFunction, agg.isDistinct) -> attribute
    -            }.toMap
    -
    -            val (functionsWithDistinct, functionsWithoutDistinct) =
    -              aggregateExpressions.partition(_.isDistinct)
    -            if 
(functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
    -              // This is a sanity check. We should not reach here when we 
have multiple distinct
    -              // column sets (aggregate.NewAggregation will not match).
    -              sys.error(
    -                "Multiple distinct column sets are not supported by the 
new aggregation" +
    -                  "code path.")
    -            }
    +      case logical.Aggregate(groupingExpressions, resultExpressions, 
child) =>
    +        // A single aggregate expression might appear multiple times in 
resultExpressions.
    +        // In order to avoid evaluating an individual aggregate function 
multiple times, we'll
    +        // build a set of the distinct aggregate expressions and build a 
function which can
    +        // be used to re-write expressions so that they reference the 
single copy of the
    +        // aggregate function which actually gets computed.
    +        val aggregateExpressions = resultExpressions.flatMap { expr =>
    +          expr.collect {
    +            case agg: AggregateExpression => agg
    +          }
    +        }.distinct
    +        // For those distinct aggregate expressions, we create a map from 
the
    +        // aggregate function to the corresponding attribute of the 
function.
    +        val aggregateFunctionToAttribute = aggregateExpressions.map { agg 
=>
    +          val aggregateFunction = agg.aggregateFunction
    +          val attribute = Alias(aggregateFunction, 
aggregateFunction.toString)().toAttribute
    +          (aggregateFunction, agg.isDistinct) -> attribute
    +        }.toMap
    +
    +        val (functionsWithDistinct, functionsWithoutDistinct) =
    +          aggregateExpressions.partition(_.isDistinct)
    +        if 
(functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
    +          // This is a sanity check. We should not reach here when we have 
multiple distinct
    +          // column sets (aggregate.NewAggregation will not match).
    +          sys.error(
    +            "Multiple distinct column sets are not supported by the new 
aggregation" +
    +              "code path.")
    +        }
     
    -            val namedGroupingExpressions = groupingExpressions.map {
    -              case ne: NamedExpression => ne -> ne
    -              // If the expression is not a NamedExpressions, we add an 
alias.
    -              // So, when we generate the result of the operator, the 
Aggregate Operator
    -              // can directly get the Seq of attributes representing the 
grouping expressions.
    -              case other =>
    -                val withAlias = Alias(other, other.toString)()
    -                other -> withAlias
    -            }
    -            val groupExpressionMap = namedGroupingExpressions.toMap
    -
    -            // The original `resultExpressions` are a set of expressions 
which may reference
    -            // aggregate expressions, grouping column values, and 
constants. When aggregate operator
    -            // emits output rows, we will use `resultExpressions` to 
generate an output projection
    -            // which takes the grouping columns and final aggregate result 
buffer as input.
    -            // Thus, we must re-write the result expressions so that their 
attributes match up with
    -            // the attributes of the final result projection's input row:
    -            val rewrittenResultExpressions = resultExpressions.map { expr 
=>
    -              expr.transformDown {
    -                case AggregateExpression2(aggregateFunction, _, 
isDistinct) =>
    -                  // The final aggregation buffer's attributes will be 
`finalAggregationAttributes`,
    -                  // so replace each aggregate expression by its 
corresponding attribute in the set:
    -                  aggregateFunctionToAttribute(aggregateFunction, 
isDistinct)
    -                case expression =>
    -                  // Since we're using `namedGroupingAttributes` to 
extract the grouping key
    -                  // columns, we need to replace grouping key expressions 
with their corresponding
    -                  // attributes. We do not rely on the equality check at 
here since attributes may
    -                  // differ cosmetically. Instead, we use semanticEquals.
    -                  groupExpressionMap.collectFirst {
    -                    case (expr, ne) if expr semanticEquals expression => 
ne.toAttribute
    -                  }.getOrElse(expression)
    -              }.asInstanceOf[NamedExpression]
    +        val namedGroupingExpressions = groupingExpressions.map {
    +          case ne: NamedExpression => ne -> ne
    +          // If the expression is not a NamedExpressions, we add an alias.
    +          // So, when we generate the result of the operator, the 
Aggregate Operator
    +          // can directly get the Seq of attributes representing the 
grouping expressions.
    +          case other =>
    +            val withAlias = Alias(other, other.toString)()
    +            other -> withAlias
    +        }
    +        val groupExpressionMap = namedGroupingExpressions.toMap
    +
    +        // The original `resultExpressions` are a set of expressions which 
may reference
    +        // aggregate expressions, grouping column values, and constants. 
When aggregate operator
    +        // emits output rows, we will use `resultExpressions` to generate 
an output projection
    +        // which takes the grouping columns and final aggregate result 
buffer as input.
    +        // Thus, we must re-write the result expressions so that their 
attributes match up with
    +        // the attributes of the final result projection's input row:
    +        val rewrittenResultExpressions = resultExpressions.map { expr =>
    +          expr.transformDown {
    +            case AggregateExpression(aggregateFunction, _, isDistinct) =>
    +              // The final aggregation buffer's attributes will be 
`finalAggregationAttributes`,
    +              // so replace each aggregate expression by its corresponding 
attribute in the set:
    +              aggregateFunctionToAttribute(aggregateFunction, isDistinct)
    +            case expression =>
    +              // Since we're using `namedGroupingAttributes` to extract 
the grouping key
    +              // columns, we need to replace grouping key expressions with 
their corresponding
    +              // attributes. We do not rely on the equality check at here 
since attributes may
    +              // differ cosmetically. Instead, we use semanticEquals.
    +              groupExpressionMap.collectFirst {
    +                case (expr, ne) if expr semanticEquals expression => 
ne.toAttribute
    +              }.getOrElse(expression)
    +          }.asInstanceOf[NamedExpression]
    +        }
    +
    +        val aggregateOperator =
    +          if 
(aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
    +            if (functionsWithDistinct.nonEmpty) {
    +              sys.error("Distinct columns cannot exist in Aggregate 
operator containing " +
    +                "aggregate functions which don't support partial 
aggregation.")
    +            } else {
    +              aggregate.Utils.planAggregateWithoutPartial(
    +                namedGroupingExpressions.map(_._2),
    +                aggregateExpressions,
    +                aggregateFunctionToAttribute,
    +                rewrittenResultExpressions,
    +                planLater(child))
                 }
    +          } else if (functionsWithDistinct.isEmpty) {
    +            aggregate.Utils.planAggregateWithoutDistinct(
    +              namedGroupingExpressions.map(_._2),
    +              aggregateExpressions,
    +              aggregateFunctionToAttribute,
    +              rewrittenResultExpressions,
    +              planLater(child))
    +          } else {
    +            aggregate.Utils.planAggregateWithOneDistinct(
    --- End diff --
    
    @yhuai I was thinking the same thing. This would make it easier to 
benchmark the different paths. We could address this in a follow-up PR.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastruct...@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to