[GitHub] [spark] cloud-fan commented on a change in pull request #20965: [SPARK-21870][SQL] Split aggregation code into small functions
cloud-fan commented on a change in pull request #20965: [SPARK-21870][SQL] Split aggregation code into small functions URL: https://github.com/apache/spark/pull/20965#discussion_r320541077 ## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala ## @@ -824,59 +944,158 @@ case class HashAggregateExec( // generating input columns, we use `currentVars`. ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input +val aggNames = aggregateExpressions.map(_.aggregateFunction.prettyName) +// Computes start offsets for each aggregation function code +// in the underlying buffer row. +val bufferStartOffsets = { + val offsets = mutable.ArrayBuffer[Int]() + var curOffset = 0 + updateExprs.foreach { exprsForOneFunc => +offsets += curOffset +curOffset += exprsForOneFunc.length + } + offsets.toArray +} + val updateRowInRegularHashMap: String = { ctx.INPUT_ROW = unsafeRowBuffer - val boundUpdateExpr = bindReferences(updateExpr, inputAttr) - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) + val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => +bindReferences(updateExprsForOneFunc, inputAttr) + } + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) val effectiveCodes = subExprs.codes.mkString("\n") - val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) { -boundUpdateExpr.map(_.genCode(ctx)) + val unsafeRowBufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => +ctx.withSubExprEliminationExprs(subExprs.states) { + boundUpdateExprsForOneFunc.map(_.genCode(ctx)) +} } - val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => -val dt = updateExpr(i).dataType -CodeGenerator.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) + + val aggCodeBlocks = updateExprs.indices.map { i => +val rowBufferEvalsForOneFunc = unsafeRowBufferEvals(i) +val boundUpdateExprsForOneFunc 
= boundUpdateExprs(i) +val bufferOffset = bufferStartOffsets(i) + +// All the update code for aggregation buffers should be placed in the end +// of each aggregation function code. +val updateRowBuffers = rowBufferEvalsForOneFunc.zipWithIndex.map { case (ev, j) => + val updateExpr = boundUpdateExprsForOneFunc(j) + val dt = updateExpr.dataType + val nullable = updateExpr.nullable + CodeGenerator.updateColumn(unsafeRowBuffer, dt, bufferOffset + j, ev, nullable) +} +code""" + |// evaluate aggregate function for ${aggNames(i)} + |${evaluateVariables(rowBufferEvalsForOneFunc)} + |// update unsafe row buffer + |${updateRowBuffers.mkString("\n").trim} + """.stripMargin + } + + lazy val nonSplitAggCode = { +s""" + |// common sub-expressions + |$effectiveCodes + |// evaluate aggregate functions and update aggregation buffers + |${aggCodeBlocks.fold(EmptyBlock)(_ + _)} + """.stripMargin + } + + if (conf.codegenSplitAggregateFunc && + aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) { +val maybeSplitCode = splitAggregateExpressions( + ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states) + +maybeSplitCode.map { updateAggCode => + s""" + |// do aggregate + |// common sub-expressions + |$effectiveCodes + |// evaluate aggregate functions and update aggregation buffers + |$updateAggCode + """.stripMargin +}.getOrElse { + nonSplitAggCode +} + } else { +nonSplitAggCode } - s""" - |// common sub-expressions - |$effectiveCodes - |// evaluate aggregate function - |${evaluateVariables(unsafeRowBufferEvals)} - |// update unsafe row buffer - |${updateUnsafeRowBuffer.mkString("\n").trim} - """.stripMargin } val updateRowInHashMap: String = { if (isFastHashMapEnabled) { if (isVectorizedHashMapEnabled) { ctx.INPUT_ROW = fastRowBuffer - val boundUpdateExpr = bindReferences(updateExpr, inputAttr) - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) + val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => 
+bindReferences(updateExprsForOneFunc, inputAttr) + } + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) val effectiveCodes = subExprs.codes.mkString("\n") - val fastRowEvals = ctx.withSu
[GitHub] [spark] cloud-fan commented on a change in pull request #20965: [SPARK-21870][SQL] Split aggregation code into small functions
cloud-fan commented on a change in pull request #20965: [SPARK-21870][SQL] Split aggregation code into small functions URL: https://github.com/apache/spark/pull/20965#discussion_r319953081 ## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala ## @@ -824,59 +944,158 @@ case class HashAggregateExec( // generating input columns, we use `currentVars`. ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input +val aggNames = aggregateExpressions.map(_.aggregateFunction.prettyName) +// Computes start offsets for each aggregation function code +// in the underlying buffer row. +val bufferStartOffsets = { + val offsets = mutable.ArrayBuffer[Int]() + var curOffset = 0 + updateExprs.foreach { exprsForOneFunc => +offsets += curOffset +curOffset += exprsForOneFunc.length + } + offsets.toArray +} + val updateRowInRegularHashMap: String = { ctx.INPUT_ROW = unsafeRowBuffer - val boundUpdateExpr = bindReferences(updateExpr, inputAttr) - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) + val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => +bindReferences(updateExprsForOneFunc, inputAttr) + } + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) val effectiveCodes = subExprs.codes.mkString("\n") - val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) { -boundUpdateExpr.map(_.genCode(ctx)) + val unsafeRowBufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => +ctx.withSubExprEliminationExprs(subExprs.states) { + boundUpdateExprsForOneFunc.map(_.genCode(ctx)) +} } - val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => -val dt = updateExpr(i).dataType -CodeGenerator.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) + + val aggCodeBlocks = updateExprs.indices.map { i => +val rowBufferEvalsForOneFunc = unsafeRowBufferEvals(i) +val boundUpdateExprsForOneFunc 
= boundUpdateExprs(i) +val bufferOffset = bufferStartOffsets(i) + +// All the update code for aggregation buffers should be placed in the end +// of each aggregation function code. +val updateRowBuffers = rowBufferEvalsForOneFunc.zipWithIndex.map { case (ev, j) => + val updateExpr = boundUpdateExprsForOneFunc(j) + val dt = updateExpr.dataType + val nullable = updateExpr.nullable + CodeGenerator.updateColumn(unsafeRowBuffer, dt, bufferOffset + j, ev, nullable) +} +code""" + |// evaluate aggregate function for ${aggNames(i)} + |${evaluateVariables(rowBufferEvalsForOneFunc)} + |// update unsafe row buffer + |${updateRowBuffers.mkString("\n").trim} + """.stripMargin + } + + lazy val nonSplitAggCode = { +s""" + |// common sub-expressions + |$effectiveCodes + |// evaluate aggregate functions and update aggregation buffers + |${aggCodeBlocks.fold(EmptyBlock)(_ + _)} + """.stripMargin + } + + if (conf.codegenSplitAggregateFunc && + aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) { +val splitAggCode = splitAggregateExpressions( Review comment: ditto This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] [spark] cloud-fan commented on a change in pull request #20965: [SPARK-21870][SQL] Split aggregation code into small functions
cloud-fan commented on a change in pull request #20965: [SPARK-21870][SQL] Split aggregation code into small functions URL: https://github.com/apache/spark/pull/20965#discussion_r319953013 ## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala ## @@ -255,41 +261,153 @@ case class HashAggregateExec( """.stripMargin } + private def isValidParamLength(paramLength: Int): Boolean = { +// This config is only for testing +sqlContext.getConf("spark.sql.HashAggregateExec.isValidParamLength", null) match { + case null | "" => CodeGenerator.isValidParamLength(paramLength) + case validLength => paramLength <= validLength.toInt +} + } + + // Splits aggregate code into small functions because the most of JVM implementations + // can not compile too long functions. + // + // Note: The difference from `CodeGenerator.splitExpressions` is that we define an individual + // function for each aggregation function (e.g., SUM and AVG). For example, in a query + // `SELECT SUM(a), AVG(a) FROM VALUES(1) t(a)`, we define two functions + // for `SUM(a)` and `AVG(a)`. + private def splitAggregateExpressions( + ctx: CodegenContext, + aggNames: Seq[String], + aggBufferUpdatingExprs: Seq[Seq[Expression]], + aggCodeBlocks: Seq[Block], + subExprs: Map[Expression, SubExprEliminationState]): Option[String] = { +val exprValsInSubExprs = subExprs.flatMap { case (_, s) => s.value :: s.isNull :: Nil } +if (exprValsInSubExprs.exists(_.isInstanceOf[SimpleExprValue])) { + // `SimpleExprValue`s cannot be used as an input variable for split functions, so + // we give up splitting functions if it exists in `subExprs`. 
+ None +} else { + val inputVars = aggBufferUpdatingExprs.map { aggExprsForOneFunc => +val inputVarsForOneFunc = aggExprsForOneFunc.map( + CodeGenerator.getLocalInputVariableValues(ctx, _, subExprs)).reduce(_ ++ _).toSeq +val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVarsForOneFunc) + +// Checks if a parameter length for the `aggExprsForOneFunc` does not go over the JVM limit +if (isValidParamLength(paramLength)) { + Some(inputVarsForOneFunc) +} else { + None +} + } + + // Checks if all the aggregate code can be split into pieces. + // If the parameter length of at least one `aggExprsForOneFunc` goes over the limit, + // we totally give up splitting aggregate code. + if (inputVars.forall(_.isDefined)) { +val splitCodes = inputVars.flatten.zipWithIndex.map { case (args, i) => + val doAggFunc = ctx.freshName(s"doAggregate_${aggNames(i)}") + val argList = args.map(v => s"${v.javaType.getName} ${v.variableName}").mkString(", ") + val doAggFuncName = ctx.addNewFunction(doAggFunc, +s""" + |private void $doAggFunc($argList) throws java.io.IOException { + | ${aggCodeBlocks(i)} + |} + """.stripMargin) + + val inputVariables = args.map(_.variableName).mkString(", ") + s"$doAggFuncName($inputVariables);" +} +Some(splitCodes.mkString("\n").trim) + } else { +val errMsg = "Failed to split aggregate code into small functions because the parameter " + + "length of at least one split function went over the JVM limit: " + + CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH +if (Utils.isTesting) { + throw new IllegalStateException(errMsg) +} else { + logInfo(errMsg) + None +} + } +} + } + private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output -val updateExpr = aggregateExpressions.flatMap { e => +// To individually generate code 
for each aggregate function, an element in `updateExprs` holds +// all the expressions for the buffer of an aggregation function. +val updateExprs = aggregateExpressions.map { e => e.mode match { case Partial | Complete => e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions case PartialMerge | Final => e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions } } -ctx.currentVars = bufVars ++ input -val boundUpdateExpr = bindReferences(updateExpr, inputAttrs) -val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) +ctx.currentVars = bufVars.flatten ++ input +val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => + bindReferences(updateExprsForOneFunc, inputAttrs) +} +val su
[GitHub] [spark] cloud-fan commented on a change in pull request #20965: [SPARK-21870][SQL] Split aggregation code into small functions
cloud-fan commented on a change in pull request #20965: [SPARK-21870][SQL] Split aggregation code into small functions URL: https://github.com/apache/spark/pull/20965#discussion_r319953154 ## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala ## @@ -824,59 +944,158 @@ case class HashAggregateExec( // generating input columns, we use `currentVars`. ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input +val aggNames = aggregateExpressions.map(_.aggregateFunction.prettyName) +// Computes start offsets for each aggregation function code +// in the underlying buffer row. +val bufferStartOffsets = { + val offsets = mutable.ArrayBuffer[Int]() + var curOffset = 0 + updateExprs.foreach { exprsForOneFunc => +offsets += curOffset +curOffset += exprsForOneFunc.length + } + offsets.toArray +} + val updateRowInRegularHashMap: String = { ctx.INPUT_ROW = unsafeRowBuffer - val boundUpdateExpr = bindReferences(updateExpr, inputAttr) - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) + val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => +bindReferences(updateExprsForOneFunc, inputAttr) + } + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) val effectiveCodes = subExprs.codes.mkString("\n") - val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) { -boundUpdateExpr.map(_.genCode(ctx)) + val unsafeRowBufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => +ctx.withSubExprEliminationExprs(subExprs.states) { + boundUpdateExprsForOneFunc.map(_.genCode(ctx)) +} } - val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => -val dt = updateExpr(i).dataType -CodeGenerator.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) + + val aggCodeBlocks = updateExprs.indices.map { i => +val rowBufferEvalsForOneFunc = unsafeRowBufferEvals(i) +val boundUpdateExprsForOneFunc 
= boundUpdateExprs(i) +val bufferOffset = bufferStartOffsets(i) + +// All the update code for aggregation buffers should be placed in the end +// of each aggregation function code. +val updateRowBuffers = rowBufferEvalsForOneFunc.zipWithIndex.map { case (ev, j) => + val updateExpr = boundUpdateExprsForOneFunc(j) + val dt = updateExpr.dataType + val nullable = updateExpr.nullable + CodeGenerator.updateColumn(unsafeRowBuffer, dt, bufferOffset + j, ev, nullable) +} +code""" + |// evaluate aggregate function for ${aggNames(i)} + |${evaluateVariables(rowBufferEvalsForOneFunc)} + |// update unsafe row buffer + |${updateRowBuffers.mkString("\n").trim} + """.stripMargin + } + + lazy val nonSplitAggCode = { +s""" + |// common sub-expressions + |$effectiveCodes + |// evaluate aggregate functions and update aggregation buffers + |${aggCodeBlocks.fold(EmptyBlock)(_ + _)} + """.stripMargin + } + + if (conf.codegenSplitAggregateFunc && + aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) { +val splitAggCode = splitAggregateExpressions( + ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states) + +splitAggCode.map { updateAggCode => + s""" + |// do aggregate + |// common sub-expressions + |$effectiveCodes + |// evaluate aggregate functions and update aggregation buffers + |$updateAggCode + """.stripMargin +}.getOrElse { + nonSplitAggCode +} + } else { +nonSplitAggCode } - s""" - |// common sub-expressions - |$effectiveCodes - |// evaluate aggregate function - |${evaluateVariables(unsafeRowBufferEvals)} - |// update unsafe row buffer - |${updateUnsafeRowBuffer.mkString("\n").trim} - """.stripMargin } val updateRowInHashMap: String = { if (isFastHashMapEnabled) { if (isVectorizedHashMapEnabled) { ctx.INPUT_ROW = fastRowBuffer - val boundUpdateExpr = bindReferences(updateExpr, inputAttr) - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) + val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => 
+bindReferences(updateExprsForOneFunc, inputAttr) + } + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) val effectiveCodes = subExprs.codes.mkString("\n") - val fastRowEvals = ctx.withSubExp
[GitHub] [spark] cloud-fan commented on a change in pull request #20965: [SPARK-21870][SQL] Split aggregation code into small functions
cloud-fan commented on a change in pull request #20965: [SPARK-21870][SQL] Split aggregation code into small functions URL: https://github.com/apache/spark/pull/20965#discussion_r319952695 ## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala ## @@ -255,41 +261,153 @@ case class HashAggregateExec( """.stripMargin } + private def isValidParamLength(paramLength: Int): Boolean = { +// This config is only for testing +sqlContext.getConf("spark.sql.HashAggregateExec.isValidParamLength", null) match { + case null | "" => CodeGenerator.isValidParamLength(paramLength) + case validLength => paramLength <= validLength.toInt +} + } + + // Splits aggregate code into small functions because most JVM implementations + // cannot compile overly long functions. Review comment: let's mention that it returns None if we are not able to split the code. This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] [spark] cloud-fan commented on a change in pull request #20965: [SPARK-21870][SQL] Split aggregation code into small functions
cloud-fan commented on a change in pull request #20965: [SPARK-21870][SQL] Split aggregation code into small functions URL: https://github.com/apache/spark/pull/20965#discussion_r319951300 ## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala ## @@ -255,41 +261,153 @@ case class HashAggregateExec( """.stripMargin } + private def isValidParamLength(paramLength: Int): Boolean = { +// This config is only for testing +sqlContext.getConf("spark.sql.HashAggregateExec.isValidParamLength", null) match { + case null | "" => CodeGenerator.isValidParamLength(paramLength) + case validLength => paramLength <= validLength.toInt +} + } + + // Splits aggregate code into small functions because the most of JVM implementations + // can not compile too long functions. + // + // Note: The difference from `CodeGenerator.splitExpressions` is that we define an individual + // function for each aggregation function (e.g., SUM and AVG). For example, in a query + // `SELECT SUM(a), AVG(a) FROM VALUES(1) t(a)`, we define two functions + // for `SUM(a)` and `AVG(a)`. + private def splitAggregateExpressions( + ctx: CodegenContext, + aggNames: Seq[String], + aggBufferUpdatingExprs: Seq[Seq[Expression]], + aggCodeBlocks: Seq[Block], + subExprs: Map[Expression, SubExprEliminationState]): Option[String] = { +val exprValsInSubExprs = subExprs.flatMap { case (_, s) => s.value :: s.isNull :: Nil } +if (exprValsInSubExprs.exists(_.isInstanceOf[SimpleExprValue])) { + // `SimpleExprValue`s cannot be used as an input variable for split functions, so Review comment: is it because `ExprValue` is not a real tree format? Otherwise we can get all the referred variables of an expr value and put them on the parameter list. This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] [spark] cloud-fan commented on a change in pull request #20965: [SPARK-21870][SQL] Split aggregation code into small functions
cloud-fan commented on a change in pull request #20965: [SPARK-21870][SQL] Split aggregation code into small functions URL: https://github.com/apache/spark/pull/20965#discussion_r319950056 ## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala ## @@ -255,41 +261,153 @@ case class HashAggregateExec( """.stripMargin } + private def isValidParamLength(paramLength: Int): Boolean = { +// This config is only for testing +sqlContext.getConf("spark.sql.HashAggregateExec.isValidParamLength", null) match { + case null | "" => CodeGenerator.isValidParamLength(paramLength) + case validLength => paramLength <= validLength.toInt Review comment: so this is the valid length, maybe a better config name is `spark.sql.HashAggregateExec.validParamLength` This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] [spark] cloud-fan commented on a change in pull request #20965: [SPARK-21870][SQL] Split aggregation code into small functions
cloud-fan commented on a change in pull request #20965: [SPARK-21870][SQL] Split aggregation code into small functions URL: https://github.com/apache/spark/pull/20965#discussion_r319949113 ## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala ## @@ -174,8 +176,9 @@ case class HashAggregateExec( } } - // The variables used as aggregation buffer. Only used for aggregation without keys. - private var bufVars: Seq[ExprCode] = _ + // The variables are used as aggregation buffers and each aggregate function has one more ExprCode Review comment: `one more` -> `one or more` This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] [spark] cloud-fan commented on a change in pull request #20965: [SPARK-21870][SQL] Split aggregation code into small functions
cloud-fan commented on a change in pull request #20965: [SPARK-21870][SQL] Split aggregation code into small functions URL: https://github.com/apache/spark/pull/20965#discussion_r318533829 ## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala ## @@ -255,41 +260,148 @@ case class HashAggregateExec( """.stripMargin } + // Splits aggregate code into small functions because most JVM implementations + // cannot compile overly long functions. + // + // Note: The difference from `CodeGenerator.splitExpressions` is that we define an individual + // function for each aggregation function (e.g., SUM and AVG). For example, in a query + // `SELECT SUM(a), AVG(a) FROM VALUES(1) t(a)`, we define two functions + // for `SUM(a)` and `AVG(a)`. + private def splitAggregateExpressions( + ctx: CodegenContext, + aggNames: Seq[String], + aggExprs: Seq[Seq[Expression]], Review comment: so it's actually `aggBufferUpdatingExprs`? This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] [spark] cloud-fan commented on a change in pull request #20965: [SPARK-21870][SQL] Split aggregation code into small functions
cloud-fan commented on a change in pull request #20965: [SPARK-21870][SQL] Split aggregation code into small functions URL: https://github.com/apache/spark/pull/20965#discussion_r318533114 ## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala ## @@ -255,41 +260,148 @@ case class HashAggregateExec( """.stripMargin } + // Splits aggregate code into small functions because most JVM implementations + // cannot compile overly long functions. + // + // Note: The difference from `CodeGenerator.splitExpressions` is that we define an individual + // function for each aggregation function (e.g., SUM and AVG). For example, in a query + // `SELECT SUM(a), AVG(a) FROM VALUES(1) t(a)`, we define two functions + // for `SUM(a)` and `AVG(a)`. + private def splitAggregateExpressions( + ctx: CodegenContext, + aggNames: Seq[String], + aggExprs: Seq[Seq[Expression]], + makeSplitAggFunctions: => Seq[String], + subExprs: Map[Expression, SubExprEliminationState]): Option[String] = { +val inputVars = aggExprs.map { aggExprsInAgg => + val inputVarsInAgg = aggExprsInAgg.map( +CodeGenerator.getLocalInputVariableValues(ctx, _, subExprs)).reduce(_ ++ _).toSeq + val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVarsInAgg) + + // Checks if a parameter length for the `aggExprsInAgg` does not go over the JVM limit + if (CodeGenerator.isValidParamLength(paramLength)) { +Some(inputVarsInAgg) + } else { +None + } +} + +// Checks if all the aggregate code can be split into pieces. +// If the parameter length of at least one `aggExprsInAgg` goes over the limit, +// we totally give up splitting aggregate code. 
+if (inputVars.forall(_.isDefined)) { + val splitAggEvalCodes = makeSplitAggFunctions + val splitCodes = inputVars.flatten.zipWithIndex.map { case (args, i) => +val doAggVal = ctx.freshName(s"doAggregateVal_${aggNames(i)}") +val argList = args.map(v => s"${v.javaType.getName} ${v.variableName}").mkString(", ") +val doAggValFuncName = ctx.addNewFunction(doAggVal, + s""" + | private void $doAggVal($argList) throws java.io.IOException { + | ${splitAggEvalCodes(i)} + | } + """.stripMargin) + +val inputVariables = args.map(_.variableName).mkString(", ") +s"$doAggValFuncName($inputVariables);" + } + Some(splitCodes.mkString("\n").trim) +} else { + val errMsg = "Failed to split aggregate code into small functions because the parameter " + +"length of at least one split function went over the JVM limit: " + +CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH + if (Utils.isTesting) { +throw new IllegalStateException(errMsg) + } else { +logInfo(errMsg) +None + } +} + } + private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output -val updateExpr = aggregateExpressions.flatMap { e => +val updateExprs = aggregateExpressions.map { e => e.mode match { case Partial | Complete => e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions case PartialMerge | Final => e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions } } -ctx.currentVars = bufVars ++ input -val boundUpdateExpr = bindReferences(updateExpr, inputAttrs) -val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) +ctx.currentVars = bufVars.flatten ++ input +val boundUpdateExprs = updateExprs.map { updateExprsInAgg => + updateExprsInAgg.map(BindReferences.bindReference(_, inputAttrs)) +} +val subExprs = 
ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) val effectiveCodes = subExprs.codes.mkString("\n") -val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) { - boundUpdateExpr.map(_.genCode(ctx)) +val aggVals = boundUpdateExprs.map { boundUpdateExprsInAgg => Review comment: ditto This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] [spark] cloud-fan commented on a change in pull request #20965: [SPARK-21870][SQL] Split aggregation code into small functions
cloud-fan commented on a change in pull request #20965: [SPARK-21870][SQL] Split aggregation code into small functions URL: https://github.com/apache/spark/pull/20965#discussion_r318532966 ## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala ## @@ -255,41 +260,148 @@ case class HashAggregateExec( """.stripMargin } + // Splits aggregate code into small functions because most JVM implementations + // cannot compile overly long functions. + // + // Note: The difference from `CodeGenerator.splitExpressions` is that we define an individual + // function for each aggregation function (e.g., SUM and AVG). For example, in a query + // `SELECT SUM(a), AVG(a) FROM VALUES(1) t(a)`, we define two functions + // for `SUM(a)` and `AVG(a)`. + private def splitAggregateExpressions( + ctx: CodegenContext, + aggNames: Seq[String], + aggExprs: Seq[Seq[Expression]], + makeSplitAggFunctions: => Seq[String], + subExprs: Map[Expression, SubExprEliminationState]): Option[String] = { +val inputVars = aggExprs.map { aggExprsInAgg => + val inputVarsInAgg = aggExprsInAgg.map( +CodeGenerator.getLocalInputVariableValues(ctx, _, subExprs)).reduce(_ ++ _).toSeq + val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVarsInAgg) + + // Checks if a parameter length for the `aggExprsInAgg` does not go over the JVM limit + if (CodeGenerator.isValidParamLength(paramLength)) { +Some(inputVarsInAgg) + } else { +None + } +} + +// Checks if all the aggregate code can be split into pieces. +// If the parameter length of at least one `aggExprsInAgg` goes over the limit, +// we totally give up splitting aggregate code. 
+if (inputVars.forall(_.isDefined)) { + val splitAggEvalCodes = makeSplitAggFunctions + val splitCodes = inputVars.flatten.zipWithIndex.map { case (args, i) => +val doAggVal = ctx.freshName(s"doAggregateVal_${aggNames(i)}") +val argList = args.map(v => s"${v.javaType.getName} ${v.variableName}").mkString(", ") +val doAggValFuncName = ctx.addNewFunction(doAggVal, + s""" + | private void $doAggVal($argList) throws java.io.IOException { + | ${splitAggEvalCodes(i)} + | } + """.stripMargin) + +val inputVariables = args.map(_.variableName).mkString(", ") +s"$doAggValFuncName($inputVariables);" + } + Some(splitCodes.mkString("\n").trim) +} else { + val errMsg = "Failed to split aggregate code into small functions because the parameter " + +"length of at least one split function went over the JVM limit: " + +CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH + if (Utils.isTesting) { +throw new IllegalStateException(errMsg) + } else { +logInfo(errMsg) +None + } +} + } + private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output -val updateExpr = aggregateExpressions.flatMap { e => +val updateExprs = aggregateExpressions.map { e => e.mode match { case Partial | Complete => e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions case PartialMerge | Final => e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions } } -ctx.currentVars = bufVars ++ input -val boundUpdateExpr = bindReferences(updateExpr, inputAttrs) -val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) +ctx.currentVars = bufVars.flatten ++ input +val boundUpdateExprs = updateExprs.map { updateExprsInAgg => Review comment: `updateExprsInAgg` -> `updateExprsForOneFunc`? This is an automated message from the Apache Git Service. 
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] [spark] cloud-fan commented on a change in pull request #20965: [SPARK-21870][SQL] Split aggregation code into small functions
cloud-fan commented on a change in pull request #20965: [SPARK-21870][SQL] Split aggregation code into small functions URL: https://github.com/apache/spark/pull/20965#discussion_r318532780 ## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala ## @@ -255,41 +260,148 @@ case class HashAggregateExec( """.stripMargin } + // Splits aggregate code into small functions because the most of JVM implementations + // can not compile too long functions. + // + // Note: The difference from `CodeGenerator.splitExpressions` is that we define an individual + // function for each aggregation function (e.g., SUM and AVG). For example, in a query + // `SELECT SUM(a), AVG(a) FROM VALUES(1) t(a)`, we define two functions + // for `SUM(a)` and `AVG(a)`. + private def splitAggregateExpressions( + ctx: CodegenContext, + aggNames: Seq[String], + aggExprs: Seq[Seq[Expression]], + makeSplitAggFunctions: => Seq[String], + subExprs: Map[Expression, SubExprEliminationState]): Option[String] = { +val inputVars = aggExprs.map { aggExprsInAgg => + val inputVarsInAgg = aggExprsInAgg.map( +CodeGenerator.getLocalInputVariableValues(ctx, _, subExprs)).reduce(_ ++ _).toSeq + val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVarsInAgg) + + // Checks if a parameter length for the `aggExprsInAgg` does not go over the JVM limit + if (CodeGenerator.isValidParamLength(paramLength)) { +Some(inputVarsInAgg) + } else { +None + } +} + +// Checks if all the aggregate code can be split into pieces. +// If the parameter length of at lease one `aggExprsInAgg` goes over the limit, +// we totally give up splitting aggregate code. 
+if (inputVars.forall(_.isDefined)) { + val splitAggEvalCodes = makeSplitAggFunctions + val splitCodes = inputVars.flatten.zipWithIndex.map { case (args, i) => +val doAggVal = ctx.freshName(s"doAggregateVal_${aggNames(i)}") +val argList = args.map(v => s"${v.javaType.getName} ${v.variableName}").mkString(", ") +val doAggValFuncName = ctx.addNewFunction(doAggVal, + s""" + | private void $doAggVal($argList) throws java.io.IOException { + | ${splitAggEvalCodes(i)} + | } + """.stripMargin) + +val inputVariables = args.map(_.variableName).mkString(", ") +s"$doAggValFuncName($inputVariables);" + } + Some(splitCodes.mkString("\n").trim) +} else { + val errMsg = "Failed to split aggregate code into small functions because the parameter " + +"length of at least one split function went over the JVM limit: " + +CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH + if (Utils.isTesting) { +throw new IllegalStateException(errMsg) + } else { +logInfo(errMsg) +None + } +} + } + private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output -val updateExpr = aggregateExpressions.flatMap { e => +val updateExprs = aggregateExpressions.map { e => e.mode match { case Partial | Complete => e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions case PartialMerge | Final => e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions } } -ctx.currentVars = bufVars ++ input -val boundUpdateExpr = bindReferences(updateExpr, inputAttrs) -val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) +ctx.currentVars = bufVars.flatten ++ input +val boundUpdateExprs = updateExprs.map { updateExprsInAgg => + updateExprsInAgg.map(BindReferences.bindReference(_, inputAttrs)) Review comment: why not follow the old code and write 
`bindReferences(updateExprsInAgg, inputAttrs)` This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] [spark] cloud-fan commented on a change in pull request #20965: [SPARK-21870][SQL] Split aggregation code into small functions
cloud-fan commented on a change in pull request #20965: [SPARK-21870][SQL] Split aggregation code into small functions URL: https://github.com/apache/spark/pull/20965#discussion_r318529700 ## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala ## @@ -255,41 +260,148 @@ case class HashAggregateExec( """.stripMargin } + // Splits aggregate code into small functions because the most of JVM implementations + // can not compile too long functions. + // + // Note: The difference from `CodeGenerator.splitExpressions` is that we define an individual + // function for each aggregation function (e.g., SUM and AVG). For example, in a query + // `SELECT SUM(a), AVG(a) FROM VALUES(1) t(a)`, we define two functions + // for `SUM(a)` and `AVG(a)`. + private def splitAggregateExpressions( + ctx: CodegenContext, + aggNames: Seq[String], + aggExprs: Seq[Seq[Expression]], + makeSplitAggFunctions: => Seq[String], + subExprs: Map[Expression, SubExprEliminationState]): Option[String] = { +val inputVars = aggExprs.map { aggExprsInAgg => + val inputVarsInAgg = aggExprsInAgg.map( +CodeGenerator.getLocalInputVariableValues(ctx, _, subExprs)).reduce(_ ++ _).toSeq + val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVarsInAgg) + + // Checks if a parameter length for the `aggExprsInAgg` does not go over the JVM limit + if (CodeGenerator.isValidParamLength(paramLength)) { +Some(inputVarsInAgg) + } else { +None + } +} + +// Checks if all the aggregate code can be split into pieces. +// If the parameter length of at lease one `aggExprsInAgg` goes over the limit, +// we totally give up splitting aggregate code. 
+if (inputVars.forall(_.isDefined)) { + val splitAggEvalCodes = makeSplitAggFunctions + val splitCodes = inputVars.flatten.zipWithIndex.map { case (args, i) => +val doAggVal = ctx.freshName(s"doAggregateVal_${aggNames(i)}") +val argList = args.map(v => s"${v.javaType.getName} ${v.variableName}").mkString(", ") +val doAggValFuncName = ctx.addNewFunction(doAggVal, + s""" + | private void $doAggVal($argList) throws java.io.IOException { + | ${splitAggEvalCodes(i)} + | } + """.stripMargin) + +val inputVariables = args.map(_.variableName).mkString(", ") +s"$doAggValFuncName($inputVariables);" + } + Some(splitCodes.mkString("\n").trim) +} else { + val errMsg = "Failed to split aggregate code into small functions because the parameter " + +"length of at least one split function went over the JVM limit: " + +CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH + if (Utils.isTesting) { +throw new IllegalStateException(errMsg) + } else { +logInfo(errMsg) +None + } +} + } + private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output -val updateExpr = aggregateExpressions.flatMap { e => +val updateExprs = aggregateExpressions.map { e => e.mode match { case Partial | Complete => e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions case PartialMerge | Final => e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions } } -ctx.currentVars = bufVars ++ input -val boundUpdateExpr = bindReferences(updateExpr, inputAttrs) -val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) +ctx.currentVars = bufVars.flatten ++ input +val boundUpdateExprs = updateExprs.map { updateExprsInAgg => + updateExprsInAgg.map(BindReferences.bindReference(_, inputAttrs)) +} +val subExprs = 
ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) val effectiveCodes = subExprs.codes.mkString("\n") -val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) { - boundUpdateExpr.map(_.genCode(ctx)) +val aggVals = boundUpdateExprs.map { boundUpdateExprsInAgg => + ctx.withSubExprEliminationExprs(subExprs.states) { +boundUpdateExprsInAgg.map(_.genCode(ctx)) + } } -// aggregate buffer should be updated atomic -val updates = aggVals.zipWithIndex.map { case (ev, i) => + +lazy val nonSplitAggCode = { + // aggregate buffer should be updated atomically + val updates = aggVals.flatten.zip(bufVars.flatten).map { case (ev, bufVar) => +s""" + | ${bufVar.isNull} = ${ev.isNull}; + | ${bufVar.value} = ${ev.value}; + """.stripMargin + } s""" - | ${bufV
[GitHub] [spark] cloud-fan commented on a change in pull request #20965: [SPARK-21870][SQL] Split aggregation code into small functions
cloud-fan commented on a change in pull request #20965: [SPARK-21870][SQL] Split aggregation code into small functions URL: https://github.com/apache/spark/pull/20965#discussion_r318525635 ## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala ## @@ -255,41 +260,148 @@ case class HashAggregateExec( """.stripMargin } + // Splits aggregate code into small functions because the most of JVM implementations + // can not compile too long functions. + // + // Note: The difference from `CodeGenerator.splitExpressions` is that we define an individual + // function for each aggregation function (e.g., SUM and AVG). For example, in a query + // `SELECT SUM(a), AVG(a) FROM VALUES(1) t(a)`, we define two functions + // for `SUM(a)` and `AVG(a)`. + private def splitAggregateExpressions( + ctx: CodegenContext, + aggNames: Seq[String], + aggExprs: Seq[Seq[Expression]], Review comment: I have trouble understanding this `aggExprs`: why is it a seq of seq? This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org