Github user maropu commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20965#discussion_r179707996
  
    --- Diff: 
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
 ---
    @@ -254,6 +256,80 @@ case class HashAggregateExec(
          """.stripMargin
       }
     
    +  // Extracts all the input variable references for a given `aggExpr`. This result will be used
    +  // to split aggregation into small functions.
    +  private def getInputVariableReferences(
    +      context: CodegenContext,
    +      aggregateExpression: Expression,
    +      subExprs: Map[Expression, SubExprEliminationState]): Seq[(String, String, Expression)] = {
    +    // `argMap` collects all the pairs of variable names and their types, the first in the pair
    +    // is a type name and the second is a variable name.
    +    val argMap = mutable.Map[(String, String), Expression]()
    +    val stack = mutable.Stack[Expression](aggregateExpression)
    +    while (stack.nonEmpty) {
    +      stack.pop() match {
    +        case e if subExprs.contains(e) =>
    +          val exprCode = subExprs(e)
    +          if (CodeGenerator.isJavaIdentifier(exprCode.value)) {
    +            argMap += (CodeGenerator.javaType(e.dataType), exprCode.value) -> e
    +          }
    +          if (CodeGenerator.isJavaIdentifier(exprCode.isNull)) {
    +            argMap += ("boolean", exprCode.isNull) -> e
    +          }
    +          // Since the children possibly has common expressions, we push them here
    +          stack.pushAll(e.children)
    +        case ref: BoundReference
    +            if context.currentVars != null && context.currentVars(ref.ordinal) != null =>
    +          val value = context.currentVars(ref.ordinal).value
    +          val isNull = context.currentVars(ref.ordinal).isNull
    +          if (CodeGenerator.isJavaIdentifier(value)) {
    +            argMap += (CodeGenerator.javaType(ref.dataType), value) -> ref
    +          }
    +          if (CodeGenerator.isJavaIdentifier(isNull)) {
    +            argMap += ("boolean", isNull) -> ref
    +          }
    +        case ref: BoundReference =>
    +          argMap += ("InternalRow", context.INPUT_ROW) -> ref
    +        case e =>
    +          stack.pushAll(e.children)
    +      }
    +    }
    +
    +    argMap.map { case ((tpe, name), e) => (tpe, name, e) }.toSeq
    +  }
    +
    +  // Splits aggregate code into small functions because JVMs does not compile too long functions
    +  private def splitAggregateExpressions(
    --- End diff --
    
    That's true, and I'll try to fix it.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to