cloud-fan commented on a change in pull request #20965: [SPARK-21870][SQL] 
Split aggregation code into small functions
URL: https://github.com/apache/spark/pull/20965#discussion_r318532780
 
 

 ##########
 File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
 ##########
 @@ -255,41 +260,148 @@ case class HashAggregateExec(
      """.stripMargin
   }
 
+  // Splits aggregate code into small functions because most JVM implementations
+  // cannot compile overly long functions.
+  //
+  // Note: The difference from `CodeGenerator.splitExpressions` is that we 
define an individual
+  // function for each aggregation function (e.g., SUM and AVG). For example, 
in a query
+  // `SELECT SUM(a), AVG(a) FROM VALUES(1) t(a)`, we define two functions
+  // for `SUM(a)` and `AVG(a)`.
+  private def splitAggregateExpressions(
+      ctx: CodegenContext,
+      aggNames: Seq[String],
+      aggExprs: Seq[Seq[Expression]],
+      makeSplitAggFunctions: => Seq[String],
+      subExprs: Map[Expression, SubExprEliminationState]): Option[String] = {
+    val inputVars = aggExprs.map { aggExprsInAgg =>
+      val inputVarsInAgg = aggExprsInAgg.map(
+        CodeGenerator.getLocalInputVariableValues(ctx, _, subExprs)).reduce(_ 
++ _).toSeq
+      val paramLength = 
CodeGenerator.calculateParamLengthFromExprValues(inputVarsInAgg)
+
+      // Checks if a parameter length for the `aggExprsInAgg` does not go over 
the JVM limit
+      if (CodeGenerator.isValidParamLength(paramLength)) {
+        Some(inputVarsInAgg)
+      } else {
+        None
+      }
+    }
+
+    // Checks if all the aggregate code can be split into pieces.
+    // If the parameter length of at least one `aggExprsInAgg` goes over the limit,
+    // we give up splitting the aggregate code entirely.
+    if (inputVars.forall(_.isDefined)) {
+      val splitAggEvalCodes = makeSplitAggFunctions
+      val splitCodes = inputVars.flatten.zipWithIndex.map { case (args, i) =>
+        val doAggVal = ctx.freshName(s"doAggregateVal_${aggNames(i)}")
+        val argList = args.map(v => s"${v.javaType.getName} 
${v.variableName}").mkString(", ")
+        val doAggValFuncName = ctx.addNewFunction(doAggVal,
+          s"""
+             | private void $doAggVal($argList) throws java.io.IOException {
+             |   ${splitAggEvalCodes(i)}
+             | }
+           """.stripMargin)
+
+        val inputVariables = args.map(_.variableName).mkString(", ")
+        s"$doAggValFuncName($inputVariables);"
+      }
+      Some(splitCodes.mkString("\n").trim)
+    } else {
+      val errMsg = "Failed to split aggregate code into small functions 
because the parameter " +
+        "length of at least one split function went over the JVM limit: " +
+        CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH
+      if (Utils.isTesting) {
+        throw new IllegalStateException(errMsg)
+      } else {
+        logInfo(errMsg)
+        None
+      }
+    }
+  }
+
   private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): 
String = {
     // only have DeclarativeAggregate
     val functions = 
aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
     val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output
-    val updateExpr = aggregateExpressions.flatMap { e =>
+    val updateExprs = aggregateExpressions.map { e =>
       e.mode match {
         case Partial | Complete =>
           
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
         case PartialMerge | Final =>
           
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
       }
     }
-    ctx.currentVars = bufVars ++ input
-    val boundUpdateExpr = bindReferences(updateExpr, inputAttrs)
-    val subExprs = 
ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
+    ctx.currentVars = bufVars.flatten ++ input
+    val boundUpdateExprs = updateExprs.map { updateExprsInAgg =>
+      updateExprsInAgg.map(BindReferences.bindReference(_, inputAttrs))
 
 Review comment:
   Why not follow the old code and write `bindReferences(updateExprsInAgg, inputAttrs)`?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to