Github user maropu commented on a diff in the pull request:
https://github.com/apache/spark/pull/20965#discussion_r183185144
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -254,6 +256,80 @@ case class HashAggregateExec(
""".stripMargin
}
+ // Extracts all the input variable references for a given `aggExpr`. This result will be used
+ // to split aggregation into small functions.
+ private def getInputVariableReferences(
+ context: CodegenContext,
+ aggregateExpression: Expression,
+ subExprs: Map[Expression, SubExprEliminationState]): Seq[(String, String, Expression)] = {
+ // `argMap` collects all the pairs of variable names and their types, the first in the pair
+ // is a type name and the second is a variable name.
+ val argMap = mutable.Map[(String, String), Expression]()
+ val stack = mutable.Stack[Expression](aggregateExpression)
+ while (stack.nonEmpty) {
+ stack.pop() match {
+ case e if subExprs.contains(e) =>
+ val exprCode = subExprs(e)
+ if (CodeGenerator.isJavaIdentifier(exprCode.value)) {
+ argMap += (CodeGenerator.javaType(e.dataType), exprCode.value) -> e
+ }
+ if (CodeGenerator.isJavaIdentifier(exprCode.isNull)) {
+ argMap += ("boolean", exprCode.isNull) -> e
+ }
+ // Since the children possibly has common expressions, we push them here
+ stack.pushAll(e.children)
+ case ref: BoundReference
+ if context.currentVars != null && context.currentVars(ref.ordinal) != null =>
+ val value = context.currentVars(ref.ordinal).value
+ val isNull = context.currentVars(ref.ordinal).isNull
+ if (CodeGenerator.isJavaIdentifier(value)) {
+ argMap += (CodeGenerator.javaType(ref.dataType), value) -> ref
+ }
+ if (CodeGenerator.isJavaIdentifier(isNull)) {
+ argMap += ("boolean", isNull) -> ref
+ }
+ case ref: BoundReference =>
+ argMap += ("InternalRow", context.INPUT_ROW) -> ref
+ case e =>
+ stack.pushAll(e.children)
+ }
+ }
+
+ argMap.map { case ((tpe, name), e) => (tpe, name, e) }.toSeq
+ }
+
+ // Splits aggregate code into small functions because JVMs does not compile too long functions
+ private def splitAggregateExpressions(
--- End diff --
ok, I'll recheck.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]