cloud-fan commented on a change in pull request #20965: [SPARK-21870][SQL] Split aggregation code into small functions
URL: https://github.com/apache/spark/pull/20965#discussion_r319953013
 
 

 ##########
 File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
 ##########
 @@ -255,41 +261,153 @@ case class HashAggregateExec(
      """.stripMargin
   }
 
+  private def isValidParamLength(paramLength: Int): Boolean = {
+    // This config is only for testing
+    sqlContext.getConf("spark.sql.HashAggregateExec.isValidParamLength", null) 
match {
+      case null | "" => CodeGenerator.isValidParamLength(paramLength)
+      case validLength => paramLength <= validLength.toInt
+    }
+  }
+
+  // Splits aggregate code into small functions because most JVM implementations
+  // cannot compile very long functions.
+  //
+  // Note: The difference from `CodeGenerator.splitExpressions` is that we define an individual
+  // function for each aggregation function (e.g., SUM and AVG). For example, in a query
+  // `SELECT SUM(a), AVG(a) FROM VALUES(1) t(a)`, we define two functions
+  // for `SUM(a)` and `AVG(a)`.
+  private def splitAggregateExpressions(
+      ctx: CodegenContext,
+      aggNames: Seq[String],
+      aggBufferUpdatingExprs: Seq[Seq[Expression]],
+      aggCodeBlocks: Seq[Block],
+      subExprs: Map[Expression, SubExprEliminationState]): Option[String] = {
+    val exprValsInSubExprs = subExprs.flatMap { case (_, s) => s.value :: 
s.isNull :: Nil }
+    if (exprValsInSubExprs.exists(_.isInstanceOf[SimpleExprValue])) {
+      // `SimpleExprValue`s cannot be used as input variables for split functions, so
+      // we give up splitting functions if any of them exists in `subExprs`.
+      None
+    } else {
+      val inputVars = aggBufferUpdatingExprs.map { aggExprsForOneFunc =>
+        val inputVarsForOneFunc = aggExprsForOneFunc.map(
+          CodeGenerator.getLocalInputVariableValues(ctx, _, 
subExprs)).reduce(_ ++ _).toSeq
+        val paramLength = 
CodeGenerator.calculateParamLengthFromExprValues(inputVarsForOneFunc)
+
+        // Checks that the parameter length for `aggExprsForOneFunc` does not go over the JVM limit
+        if (isValidParamLength(paramLength)) {
+          Some(inputVarsForOneFunc)
+        } else {
+          None
+        }
+      }
+
+      // Checks if all the aggregate code can be split into pieces.
+      // If the parameter length of at least one `aggExprsForOneFunc` goes over the limit,
+      // we totally give up splitting aggregate code.
+      if (inputVars.forall(_.isDefined)) {
+        val splitCodes = inputVars.flatten.zipWithIndex.map { case (args, i) =>
+          val doAggFunc = ctx.freshName(s"doAggregate_${aggNames(i)}")
+          val argList = args.map(v => s"${v.javaType.getName} 
${v.variableName}").mkString(", ")
+          val doAggFuncName = ctx.addNewFunction(doAggFunc,
+            s"""
+               |private void $doAggFunc($argList) throws java.io.IOException {
+               |  ${aggCodeBlocks(i)}
+               |}
+             """.stripMargin)
+
+          val inputVariables = args.map(_.variableName).mkString(", ")
+          s"$doAggFuncName($inputVariables);"
+        }
+        Some(splitCodes.mkString("\n").trim)
+      } else {
+        val errMsg = "Failed to split aggregate code into small functions 
because the parameter " +
+          "length of at least one split function went over the JVM limit: " +
+          CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH
+        if (Utils.isTesting) {
+          throw new IllegalStateException(errMsg)
+        } else {
+          logInfo(errMsg)
+          None
+        }
+      }
+    }
+  }
+
   private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): 
String = {
     // only have DeclarativeAggregate
     val functions = 
aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
     val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output
-    val updateExpr = aggregateExpressions.flatMap { e =>
+    // To individually generate code for each aggregate function, an element in `updateExprs` holds
+    // all the expressions for the buffer of an aggregation function.
+    val updateExprs = aggregateExpressions.map { e =>
       e.mode match {
         case Partial | Complete =>
           
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
         case PartialMerge | Final =>
           
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
       }
     }
-    ctx.currentVars = bufVars ++ input
-    val boundUpdateExpr = bindReferences(updateExpr, inputAttrs)
-    val subExprs = 
ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
+    ctx.currentVars = bufVars.flatten ++ input
+    val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
+      bindReferences(updateExprsForOneFunc, inputAttrs)
+    }
+    val subExprs = 
ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
     val effectiveCodes = subExprs.codes.mkString("\n")
-    val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) {
-      boundUpdateExpr.map(_.genCode(ctx))
+    val bufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
+      ctx.withSubExprEliminationExprs(subExprs.states) {
+        boundUpdateExprsForOneFunc.map(_.genCode(ctx))
+      }
     }
-    // aggregate buffer should be updated atomic
-    val updates = aggVals.zipWithIndex.map { case (ev, i) =>
-      s"""
-         | ${bufVars(i).isNull} = ${ev.isNull};
-         | ${bufVars(i).value} = ${ev.value};
+
+    val aggNames = functions.map(_.prettyName)
+    val aggCodeBlocks = bufferEvals.zipWithIndex.map { case 
(bufferEvalsForOneFunc, i) =>
+      val bufVarsForOneFunc = bufVars(i)
+      // All the update code for aggregation buffers should be placed at the end
+      // of each aggregation function's code.
+      val updates = bufferEvalsForOneFunc.zip(bufVarsForOneFunc).map { case 
(ev, bufVar) =>
+        s"""
+           |${bufVar.isNull} = ${ev.isNull};
+           |${bufVar.value} = ${ev.value};
+         """.stripMargin
+      }
+      code"""
+         |// do aggregate for ${aggNames(i)}
+         |// evaluate aggregate function
+         |${evaluateVariables(bufferEvalsForOneFunc)}
+         |// update aggregation buffers
+         |${updates.mkString("\n").trim}
        """.stripMargin
     }
-    s"""
-       | // do aggregate
-       | // common sub-expressions
-       | $effectiveCodes
-       | // evaluate aggregate function
-       | ${evaluateVariables(aggVals)}
-       | // update aggregation buffer
-       | ${updates.mkString("\n").trim}
-     """.stripMargin
+
+    lazy val nonSplitAggCode = {
+       s"""
+         |// do aggregate
+         |// common sub-expressions
+         |$effectiveCodes
+         |// evaluate aggregate functions and update aggregation buffers
+         |${aggCodeBlocks.fold(EmptyBlock)(_ + _)}
+       """.stripMargin
+    }
+
+    if (conf.codegenSplitAggregateFunc &&
+        aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
+      val splitAggCode = splitAggregateExpressions(
 
 Review comment:
   nit: `maybeSplitCode` or `splitCodeOpt`. It helps people understand why there is a `map` call below.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to