Github user stanzhai commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19301#discussion_r140699522
  
    --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala ---
    @@ -72,11 +74,19 @@ object AggregateExpression {
           aggregateFunction: AggregateFunction,
           mode: AggregateMode,
           isDistinct: Boolean): AggregateExpression = {
    +    val state = if (aggregateFunction.resolved) {
    +      Seq(aggregateFunction.toString, aggregateFunction.dataType,
    +        aggregateFunction.nullable, mode, isDistinct)
    +    } else {
    +      Seq(aggregateFunction.toString, mode, isDistinct)
    +    }
    +    val hashCode = state.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b)
    +
         AggregateExpression(
           aggregateFunction,
           mode,
           isDistinct,
    -      NamedExpression.newExprId)
    +      ExprId(hashCode))
    --- End diff ---
    
    I've tried to optimize this in the aggregate planner (https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala#L211):
    
    ```scala
          // A single aggregate expression might appear multiple times in resultExpressions.
          // In order to avoid evaluating an individual aggregate function multiple times, we'll
          // build a set of the distinct aggregate expressions and build a function which can
          // be used to re-write expressions so that they reference the single copy of the
          // aggregate function which actually gets computed.
          val aggregateExpressions = resultExpressions.flatMap { expr =>
            expr.collect {
              case agg: AggregateExpression =>
                val aggregateFunction = agg.aggregateFunction
                val state = if (aggregateFunction.resolved) {
                  Seq(aggregateFunction.toString, aggregateFunction.dataType,
                    aggregateFunction.nullable, agg.mode, agg.isDistinct)
                } else {
                  Seq(aggregateFunction.toString, agg.mode, agg.isDistinct)
                }
                val hashCode = state.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b)
                (hashCode, agg)
            }
          }.groupBy(_._1).map { case (_, values) =>
            values.head._2
          }.toSeq
    ```
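
    For reference, the `state.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b)` fold above is the standard 31-multiplier hash combine over the element hash codes. Here is a minimal, self-contained sketch of that idiom; the values in `state` are made-up placeholders standing in for the real Catalyst state, not actual Catalyst objects:

    ```scala
    import java.util.Objects

    object HashCombineSketch {
      def main(args: Array[String]): Unit = {
        // Hypothetical stand-ins for a resolved aggregate's identifying state:
        // (aggregateFunction.toString, dataType, nullable, mode, isDistinct).
        val state: Seq[Any] = Seq("max(value#1)", "IntegerType", true, "Complete", false)

        // Objects.hashCode is null-safe; the fold mixes each element's hash with
        // a multiplier of 31, so equal state sequences produce equal hash codes.
        val hashCode = state.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b)
        println(hashCode)
      }
    }
    ```

    One caveat of keying the `groupBy` on a raw hash is that two distinct aggregates could in principle collide and be conflated; comparing the `state` sequences themselves in addition to the hash would avoid that.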
    
    But it's difficult to distinguish between different typed aggregators without an expr id. The current solution works well for all aggregate functions.
    
    I'm not familiar with typed aggregators; any suggestions would be appreciated.

