[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user maropu commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r179391181
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -257,6 +259,78 @@ case class HashAggregateExec(
""".stripMargin
}
+  // Extracts all the input variable references for a given `aggExpr`. This result will be used
+  // to split aggregation into small functions.
+  private def getInputVariableReferences(
+      context: CodegenContext,
+      aggregateExpression: Expression,
+      subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = {
+    // `argSet` collects all the pairs of variable names and their types; the first in the pair is
+    // a type name and the second is a variable name.
+    val argSet = mutable.Set[(String, String)]()
+    val stack = mutable.Stack[Expression](aggregateExpression)
+    while (stack.nonEmpty) {
+      stack.pop() match {
+        case e if subExprs.contains(e) =>
+          val exprCode = subExprs(e)
+          if (CodegenContext.isJavaIdentifier(exprCode.value)) {
--- End diff --
hey, good news! Thanks for letting me know ;)
---
-
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
Github user rednaxelafx commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r179367236
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -257,6 +259,78 @@ case class HashAggregateExec(
""".stripMargin
}
+  // Extracts all the input variable references for a given `aggExpr`. This result will be used
+  // to split aggregation into small functions.
+  private def getInputVariableReferences(
+      context: CodegenContext,
+      aggregateExpression: Expression,
+      subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = {
+    // `argSet` collects all the pairs of variable names and their types; the first in the pair is
+    // a type name and the second is a variable name.
+    val argSet = mutable.Set[(String, String)]()
+    val stack = mutable.Stack[Expression](aggregateExpression)
+    while (stack.nonEmpty) {
+      stack.pop() match {
+        case e if subExprs.contains(e) =>
+          val exprCode = subExprs(e)
+          if (CodegenContext.isJavaIdentifier(exprCode.value)) {
--- End diff --
Once we have @viirya 's https://github.com/apache/spark/pull/20043 merged
we won't need the ugly `CodegenContext.isJavaIdentifier` hack any more >_<|||
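
For readers without the PR context: the check exists because an `ExprCode`'s `value`/`isNull` strings are sometimes generated variable names (which must become parameters of the split functions) and sometimes literal fragments such as `true` or `-1` (which must not). A minimal, self-contained sketch of that filtering idea — illustrative names only, not the Spark implementation:

```scala
// Illustrative sketch of the identifier filter discussed above: only valid,
// non-keyword Java identifiers can be forwarded as method parameters.
object IdentifierSketch {
  // A few keywords/literals that can show up in generated code fragments.
  private val javaKeywords =
    Set("true", "false", "null", "int", "long", "boolean", "double", "this")

  def isJavaIdentifier(str: String): Boolean = str match {
    case null | "" => false
    case s =>
      !javaKeywords.contains(s) &&
        Character.isJavaIdentifierStart(s.charAt(0)) &&
        (1 until s.length).forall(i => Character.isJavaIdentifierPart(s.charAt(i)))
  }

  // Keep only the (type, name) pairs that are real variable references.
  def collectArgs(candidates: Seq[(String, String)]): Seq[(String, String)] =
    candidates.filter { case (_, name) => isJavaIdentifier(name) }
}
```

With this filter, `collectArgs(Seq(("int", "agg_value1"), ("boolean", "true"), ("long", "-1")))` keeps only the `agg_value1` entry.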
---
Github user maropu commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r178716133
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -825,52 +924,92 @@ case class HashAggregateExec(
      ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input
      val updateRowInRegularHashMap: String = {
-       ctx.INPUT_ROW = unsafeRowBuffer
+       // We need to copy the aggregation row buffer to a local row first because each aggregate
+       // function directly updates the buffer when it finishes.
--- End diff --
We need this copy because:
https://github.com/apache/spark/pull/19082#discussion_r143326742
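
To make the hazard concrete, here is a tiny standalone sketch (not Spark code; a plain array stands in for the aggregation row buffer). Once the update code is split into separate functions that each write straight back into the shared buffer, a later function can observe an earlier function's update instead of the original buffer state; copying first keeps every function reading consistent inputs:

```scala
// Illustrative sketch of the read-after-write hazard that motivates the local
// copy: "function 2" should read the buffer state that existed before
// "function 1" ran, not function 1's freshly written result.
object CopySketch {
  def updateBoth(buffer: Array[Long], input: Long, copyFirst: Boolean): Array[Long] = {
    // With copyFirst = true, reads go against a snapshot of the buffer,
    // matching what the generated local row copy achieves.
    val readFrom = if (copyFirst) buffer.clone() else buffer
    // "function 1": adds the input and writes straight back into slot 0.
    buffer(0) = readFrom(0) + input
    // "function 2": also reads slot 0 of the row it was handed.
    buffer(1) = readFrom(0) * 2
    buffer
  }
}
```

Running `updateBoth(Array(10L, 0L), 5L, copyFirst = false)` yields `[15, 30]` instead of the intended `[15, 20]`, which is exactly the aliasing problem the comment describes.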
---
Github user cloud-fan commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r163305111
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -825,52 +924,92 @@ case class HashAggregateExec(
      ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input
      val updateRowInRegularHashMap: String = {
-       ctx.INPUT_ROW = unsafeRowBuffer
+       // We need to copy the aggregation row buffer to a local row first because each aggregate
+       // function directly updates the buffer when it finishes.
--- End diff --
why does this matter? We should avoid unnecessary data copies as much as we can.
---
Github user maropu closed the pull request at: https://github.com/apache/spark/pull/19082
---
Github user maropu commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r156249393
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -269,28 +343,50 @@ case class HashAggregateExec(
          e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
        }
      }
-     ctx.currentVars = bufVars ++ input
+
+     // We need to copy the aggregation buffer to local variables first because each aggregate
+     // function directly updates the buffer when it finishes.
--- End diff --
just FYI: we do need the local copies, per this discussion, too:
https://github.com/apache/spark/pull/19865
---
Github user maropu commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r156229346
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -863,25 +984,43 @@ case class HashAggregateExec(
}
      val updateRowInUnsafeRowMap: String = {
-       ctx.INPUT_ROW = unsafeRowBuffer
+       // We need to copy the aggregation row buffer to a local row first because each aggregate
+       // function directly updates the buffer when it finishes.
+       val localRowBuffer = ctx.freshName("localUnsafeRowBuffer")
+       val initLocalRowBuffer = s"InternalRow $localRowBuffer = $unsafeRowBuffer.copy();"
+
+       ctx.INPUT_ROW = localRowBuffer
        val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
        val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
        val effectiveCodes = subExprs.codes.mkString("\n")
        val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
          boundUpdateExpr.map(_.genCode(ctx))
        }
-       val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) =>
+
+       val evalAndUpdateCodes = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) =>
          val dt = updateExpr(i).dataType
-         ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable)
+         val updateColumnCode = ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable)
+         s"""
+            | // evaluate aggregate function
+            | ${ev.code}
+            | // update unsafe row buffer
+            | $updateColumnCode
+          """.stripMargin
        }
+
+       val updateAggValCode = splitAggregateExpressions(
+         ctx, boundUpdateExpr, evalAndUpdateCodes, subExprs.states,
+         Seq(("InternalRow", unsafeRowBuffer)))
--- End diff --
ok
---
Github user gatorsmile commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r156184646
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -863,25 +984,43 @@ case class HashAggregateExec(
}
      val updateRowInUnsafeRowMap: String = {
-       ctx.INPUT_ROW = unsafeRowBuffer
+       // We need to copy the aggregation row buffer to a local row first because each aggregate
+       // function directly updates the buffer when it finishes.
+       val localRowBuffer = ctx.freshName("localUnsafeRowBuffer")
+       val initLocalRowBuffer = s"InternalRow $localRowBuffer = $unsafeRowBuffer.copy();"
+
+       ctx.INPUT_ROW = localRowBuffer
        val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
        val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
        val effectiveCodes = subExprs.codes.mkString("\n")
        val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
          boundUpdateExpr.map(_.genCode(ctx))
        }
-       val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) =>
+
+       val evalAndUpdateCodes = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) =>
          val dt = updateExpr(i).dataType
-         ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable)
+         val updateColumnCode = ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable)
+         s"""
+            | // evaluate aggregate function
+            | ${ev.code}
+            | // update unsafe row buffer
+            | $updateColumnCode
+          """.stripMargin
        }
+
+       val updateAggValCode = splitAggregateExpressions(
+         ctx, boundUpdateExpr, evalAndUpdateCodes, subExprs.states,
+         Seq(("InternalRow", unsafeRowBuffer)))
--- End diff --
```
ctx,
boundUpdateExpr,
evalAndUpdateCodes,
subExprs.states,
Seq(("InternalRow", unsafeRowBuffer)))
```
---
Github user maropu commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r156092874
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -863,25 +984,43 @@ case class HashAggregateExec(
}
      val updateRowInUnsafeRowMap: String = {
-       ctx.INPUT_ROW = unsafeRowBuffer
+       // We need to copy the aggregation row buffer to a local row first because each aggregate
+       // function directly updates the buffer when it finishes.
+       val localRowBuffer = ctx.freshName("localUnsafeRowBuffer")
+       val initLocalRowBuffer = s"InternalRow $localRowBuffer = $unsafeRowBuffer.copy();"
+
+       ctx.INPUT_ROW = localRowBuffer
        val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
        val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
        val effectiveCodes = subExprs.codes.mkString("\n")
        val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
          boundUpdateExpr.map(_.genCode(ctx))
        }
-       val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) =>
+
+       val evalAndUpdateCodes = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) =>
          val dt = updateExpr(i).dataType
-         ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable)
+         val updateColumnCode = ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable)
+         s"""
+            | // evaluate aggregate function
+            | ${ev.code}
+            | // update unsafe row buffer
+            | $updateColumnCode
+          """.stripMargin
        }
+
+       val updateAggValCode = splitAggregateExpressions(
+         ctx, boundUpdateExpr, evalAndUpdateCodes, subExprs.states,
+         Seq(("InternalRow", unsafeRowBuffer)))
--- End diff --
Need more indents here?
---
Github user maropu commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r156092752
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -256,6 +258,85 @@ case class HashAggregateExec(
""".stripMargin
}
+  // Extracts all the input variable references for a given `aggExpr`. This result will be used
+  // to split aggregation into small functions.
+  private def getInputVariableReferences(
+      ctx: CodegenContext,
+      aggExpr: Expression,
+      subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = {
+    // `argSet` collects all the pairs of variable names and their types; the first in the pair is
+    // a type name and the second is a variable name.
+    val argSet = mutable.Set[(String, String)]()
+    val stack = mutable.Stack[Expression](aggExpr)
+    while (stack.nonEmpty) {
+      stack.pop() match {
+        case e if subExprs.contains(e) =>
+          val exprCode = subExprs(e)
+          if (CodegenContext.isJavaIdentifier(exprCode.value)) {
+            argSet += ((ctx.javaType(e.dataType), exprCode.value))
+          }
+          if (CodegenContext.isJavaIdentifier(exprCode.isNull)) {
+            argSet += (("boolean", exprCode.isNull))
+          }
+          // Since the children possibly have common expressions, we push them here
+          stack.pushAll(e.children)
+        case ref: BoundReference
+            if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null =>
+          val value = ctx.currentVars(ref.ordinal).value
+          val isNull = ctx.currentVars(ref.ordinal).isNull
+          if (CodegenContext.isJavaIdentifier(value)) {
+            argSet += ((ctx.javaType(ref.dataType), value))
+          }
+          if (CodegenContext.isJavaIdentifier(isNull)) {
+            argSet += (("boolean", isNull))
+          }
+        case _: BoundReference =>
+          argSet += (("InternalRow", ctx.INPUT_ROW))
+        case e =>
+          stack.pushAll(e.children)
+      }
+    }
+
+    argSet.toSet
+  }
+
+  // Splits aggregate code into small functions because the JVM does not compile
+  // overly long functions.
+  private def splitAggregateExpressions(
+      ctx: CodegenContext,
+      aggExprs: Seq[Expression],
+      evalAndUpdateCodes: Seq[String],
+      subExprs: Map[Expression, SubExprEliminationState],
+      otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = {
+    aggExprs.zipWithIndex.map { case (aggExpr, i) =>
+      // The maximum length of parameters in non-static Java methods is 254, but a parameter of
+      // type long or double contributes two units to the length. So, this method gives up
+      // splitting the code if the parameter length goes over 127.
+      val args = (getInputVariableReferences(ctx, aggExpr, subExprs) ++ otherArgs).toSeq
+
+      // This is for testing/benchmarking only
+      val maxParamNumInJavaMethod =
+        sqlContext.getConf("spark.sql.codegen.aggregate.maxParamNumInJavaMethod", null) match {
--- End diff --
ok
---
Github user gatorsmile commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r156002166
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -805,26 +908,44 @@ case class HashAggregateExec(
      def updateRowInFastHashMap(isVectorized: Boolean): Option[String] = {
-       ctx.INPUT_ROW = fastRowBuffer
+       // We need to copy the aggregation row buffer to a local row first because each aggregate
+       // function directly updates the buffer when it finishes.
+       val localRowBuffer = ctx.freshName("localFastRowBuffer")
+       val initLocalRowBuffer = s"InternalRow $localRowBuffer = $fastRowBuffer.copy();"
+
+       ctx.INPUT_ROW = localRowBuffer
        val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
        val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
        val effectiveCodes = subExprs.codes.mkString("\n")
        val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
          boundUpdateExpr.map(_.genCode(ctx))
        }
-       val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) =>
+
+       val evalAndUpdateCodes = fastRowEvals.zipWithIndex.map { case (ev, i) =>
          val dt = updateExpr(i).dataType
-         ctx.updateColumn(fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorized)
+         val updateColumnCode = ctx.updateColumn(
+           fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorized)
+         s"""
+            | // evaluate aggregate function
+            | ${ev.code}
+            | // update fast row
+            | $updateColumnCode
+          """.stripMargin
        }
+
+       val updateAggValCode = splitAggregateExpressions(
+         ctx, boundUpdateExpr, evalAndUpdateCodes, subExprs.states,
+         Seq(("InternalRow", fastRowBuffer)))
--- End diff --
indents
---
Github user gatorsmile commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r15593
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
---
@@ -1070,6 +1071,24 @@ class CodegenContext {
}
}
+object CodegenContext {
+
+  private val javaKeywords = Set(
+    "abstract", "assert", "boolean", "break", "byte", "case", "catch", "char", "class", "const",
+    "continue", "default", "do", "double", "else", "extends", "false", "final", "finally", "float",
+    "for", "goto", "if", "implements", "import", "instanceof", "int", "interface", "long", "native",
+    "new", "null", "package", "private", "protected", "public", "return", "short", "static",
+    "strictfp", "super", "switch", "synchronized", "this", "throw", "throws", "transient", "true",
+    "try", "void", "volatile", "while"
+  )
+
+  def isJavaIdentifier(str: String): Boolean = str match {
+    case null | "" => false
+    case _ => !javaKeywords.contains(str) && isJavaIdentifierStart(str.charAt(0)) &&
+      (1 until str.length).forall(i => isJavaIdentifierPart(str.charAt(i)))
+  }
--- End diff --
```Scala
/**
* Returns true if the given `str` is a valid java identifier.
*/
def isJavaIdentifier(str: String): Boolean = str match {
case null | "" =>
false
case _ =>
!javaKeywords.contains(str) && isJavaIdentifierStart(str.charAt(0)) &&
(1 until str.length).forall(i =>
isJavaIdentifierPart(str.charAt(i)))
}
```
---
Github user gatorsmile commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r156003342
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -256,6 +258,85 @@ case class HashAggregateExec(
""".stripMargin
}
+  // Extracts all the input variable references for a given `aggExpr`. This result will be used
+  // to split aggregation into small functions.
+  private def getInputVariableReferences(
+      ctx: CodegenContext,
+      aggExpr: Expression,
+      subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = {
+    // `argSet` collects all the pairs of variable names and their types; the first in the pair is
+    // a type name and the second is a variable name.
+    val argSet = mutable.Set[(String, String)]()
+    val stack = mutable.Stack[Expression](aggExpr)
+    while (stack.nonEmpty) {
+      stack.pop() match {
+        case e if subExprs.contains(e) =>
+          val exprCode = subExprs(e)
+          if (CodegenContext.isJavaIdentifier(exprCode.value)) {
+            argSet += ((ctx.javaType(e.dataType), exprCode.value))
+          }
+          if (CodegenContext.isJavaIdentifier(exprCode.isNull)) {
+            argSet += (("boolean", exprCode.isNull))
+          }
+          // Since the children possibly have common expressions, we push them here
+          stack.pushAll(e.children)
+        case ref: BoundReference
+            if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null =>
+          val value = ctx.currentVars(ref.ordinal).value
+          val isNull = ctx.currentVars(ref.ordinal).isNull
+          if (CodegenContext.isJavaIdentifier(value)) {
+            argSet += ((ctx.javaType(ref.dataType), value))
+          }
+          if (CodegenContext.isJavaIdentifier(isNull)) {
+            argSet += (("boolean", isNull))
+          }
+        case _: BoundReference =>
+          argSet += (("InternalRow", ctx.INPUT_ROW))
+        case e =>
+          stack.pushAll(e.children)
+      }
+    }
+
+    argSet.toSet
+  }
+
+  // Splits aggregate code into small functions because the JVM does not compile
+  // overly long functions.
+  private def splitAggregateExpressions(
+      ctx: CodegenContext,
+      aggExprs: Seq[Expression],
+      evalAndUpdateCodes: Seq[String],
+      subExprs: Map[Expression, SubExprEliminationState],
+      otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = {
+    aggExprs.zipWithIndex.map { case (aggExpr, i) =>
+      // The maximum length of parameters in non-static Java methods is 254, but a parameter of
+      // type long or double contributes two units to the length. So, this method gives up
+      // splitting the code if the parameter length goes over 127.
+      val args = (getInputVariableReferences(ctx, aggExpr, subExprs) ++ otherArgs).toSeq
+
+      // This is for testing/benchmarking only
+      val maxParamNumInJavaMethod =
+        sqlContext.getConf("spark.sql.codegen.aggregate.maxParamNumInJavaMethod", null) match {
--- End diff --
Let us introduce an internal SQLConf. If the number is high enough, we can
disable this feature.
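
As a hedged sketch of the pattern being discussed (the conf key comes from the diff; `ConfSketch` itself is hypothetical and not Spark's `SQLConf` machinery), reading the cap with a fallback default looks roughly like this:

```scala
// Hypothetical helper mirroring the getConf pattern in the diff: the key
// "spark.sql.codegen.aggregate.maxParamNumInJavaMethod" appears there, and 127
// is the conservative cap its comment derives from the JVM's 254-slot limit.
object ConfSketch {
  // If the conf is unset, fall back to the cap of 127 used in the diff.
  def maxParamNum(raw: Option[String], default: Int = 127): Int =
    raw.map(_.toInt).getOrElse(default)

  // The split is abandoned whenever the collected argument list exceeds the cap.
  def giveUpSplitting(numArgs: Int, cap: Int): Boolean = numArgs > cap
}
```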
---
Github user gatorsmile commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r156000426
--- Diff:
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
---
@@ -380,4 +380,19 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
        s"Incorrect Evaluation: expressions: $exprAnd, actual: $actualAnd, expected: $expectedAnd")
     }
   }
+
+  test("SPARK-21870 check if CodegenContext.isJavaIdentifier works correctly") {
+    assert(CodegenContext.isJavaIdentifier("agg_value") === true)
+    assert(CodegenContext.isJavaIdentifier("agg_value1") === true)
+    assert(CodegenContext.isJavaIdentifier("bhj_value4") === true)
+    assert(CodegenContext.isJavaIdentifier("smj_value6") === true)
+    assert(CodegenContext.isJavaIdentifier("rdd_value7") === true)
+    assert(CodegenContext.isJavaIdentifier("scan_isNull") === true)
+    assert(CodegenContext.isJavaIdentifier("test") === true)
+    assert(CodegenContext.isJavaIdentifier("true") === false)
+    assert(CodegenContext.isJavaIdentifier("false") === false)
+    assert(CodegenContext.isJavaIdentifier("390239") === false)
+    assert(CodegenContext.isJavaIdentifier(literal) === false)
+    assert(CodegenContext.isJavaIdentifier(double) === false)
--- End diff --
```Scala
import CodegenContext.isJavaIdentifier
// positive cases
assert(isJavaIdentifier("agg_value"))
assert(isJavaIdentifier("agg_value1"))
assert(isJavaIdentifier("bhj_value4"))
assert(isJavaIdentifier("smj_value6"))
assert(isJavaIdentifier("rdd_value7"))
assert(isJavaIdentifier("scan_isNull"))
assert(isJavaIdentifier("test"))
// negative cases
assert(!isJavaIdentifier("true"))
assert(!isJavaIdentifier("false"))
assert(!isJavaIdentifier("390239"))
assert(!isJavaIdentifier(literal))
assert(!isJavaIdentifier(double))
```
---
Github user gatorsmile commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r156002134
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -863,25 +984,43 @@ case class HashAggregateExec(
}
      val updateRowInUnsafeRowMap: String = {
-       ctx.INPUT_ROW = unsafeRowBuffer
+       // We need to copy the aggregation row buffer to a local row first because each aggregate
+       // function directly updates the buffer when it finishes.
+       val localRowBuffer = ctx.freshName("localUnsafeRowBuffer")
+       val initLocalRowBuffer = s"InternalRow $localRowBuffer = $unsafeRowBuffer.copy();"
+
+       ctx.INPUT_ROW = localRowBuffer
        val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
        val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
        val effectiveCodes = subExprs.codes.mkString("\n")
        val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
          boundUpdateExpr.map(_.genCode(ctx))
        }
-       val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) =>
+
+       val evalAndUpdateCodes = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) =>
          val dt = updateExpr(i).dataType
-         ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable)
+         val updateColumnCode = ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable)
+         s"""
+            | // evaluate aggregate function
+            | ${ev.code}
+            | // update unsafe row buffer
+            | $updateColumnCode
+          """.stripMargin
        }
+
+       val updateAggValCode = splitAggregateExpressions(
+         ctx, boundUpdateExpr, evalAndUpdateCodes, subExprs.states,
+         Seq(("InternalRow", unsafeRowBuffer)))
--- End diff --
indents.
---
Github user viirya commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r143903396
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -244,6 +246,89 @@ case class HashAggregateExec(
protected override val shouldStopRequired = false
+  // Extracts all the input variable references for a given `aggExpr`. This result will be used
+  // to split aggregation into small functions.
+  private def getInputVariableReferences(
+      ctx: CodegenContext,
+      aggExpr: Expression,
+      subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = {
+    // `argSet` collects all the pairs of variable names and their types; the first in the pair is
+    // a type name and the second is a variable name.
+    val argSet = mutable.Set[(String, String)]()
+    val stack = mutable.Stack[Expression](aggExpr)
+    while (stack.nonEmpty) {
+      stack.pop() match {
+        case e if subExprs.contains(e) =>
+          val exprCode = subExprs(e)
+          if (CodegenContext.isJavaIdentifier(exprCode.value)) {
+            argSet += ((ctx.javaType(e.dataType), exprCode.value))
+          }
+          if (CodegenContext.isJavaIdentifier(exprCode.isNull)) {
+            argSet += (("boolean", exprCode.isNull))
+          }
+          // Since the children possibly have common expressions, we push them here
+          stack.pushAll(e.children)
+        case ref: BoundReference
+            if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null =>
+          val value = ctx.currentVars(ref.ordinal).value
+          val isNull = ctx.currentVars(ref.ordinal).isNull
+          if (CodegenContext.isJavaIdentifier(value)) {
+            argSet += ((ctx.javaType(ref.dataType), value))
+          }
+          if (CodegenContext.isJavaIdentifier(isNull)) {
+            argSet += (("boolean", isNull))
+          }
+        case _: BoundReference =>
+          argSet += (("InternalRow", ctx.INPUT_ROW))
+        case e =>
+          stack.pushAll(e.children)
+      }
+    }
+
+    argSet.toSet
+  }
+
+  // Splits the aggregation into small functions because HotSpot does not compile
+  // overly long functions.
+  private def splitAggregateExpressions(
+      ctx: CodegenContext,
+      aggExprs: Seq[Expression],
+      evalAndUpdateCodes: Seq[String],
+      subExprs: Map[Expression, SubExprEliminationState],
+      otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = {
+    aggExprs.zipWithIndex.map { case (aggExpr, i) =>
+      // The maximum number of parameters in non-static Java methods is 254, so this method gives
+      // up splitting the code if the number goes over the limit.
+      // You can find more information about the limit in the JVM specification:
+      // - The number of method parameters is limited to 255 by the definition of a method
--- End diff --
@kiszk Thanks for pinging me. I've updated the similar check in #18931 too.
---
Github user viirya commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r143897243
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -244,6 +246,89 @@ case class HashAggregateExec(
protected override val shouldStopRequired = false
+  // Extracts all the input variable references for a given `aggExpr`. This result will be used
+  // to split aggregation into small functions.
+  private def getInputVariableReferences(
+      ctx: CodegenContext,
+      aggExpr: Expression,
+      subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = {
+    // `argSet` collects all the pairs of variable names and their types; the first in the pair is
+    // a type name and the second is a variable name.
+    val argSet = mutable.Set[(String, String)]()
+    val stack = mutable.Stack[Expression](aggExpr)
+    while (stack.nonEmpty) {
+      stack.pop() match {
+        case e if subExprs.contains(e) =>
+          val exprCode = subExprs(e)
+          if (CodegenContext.isJavaIdentifier(exprCode.value)) {
+            argSet += ((ctx.javaType(e.dataType), exprCode.value))
+          }
+          if (CodegenContext.isJavaIdentifier(exprCode.isNull)) {
+            argSet += (("boolean", exprCode.isNull))
+          }
+          // Since the children possibly have common expressions, we push them here
+          stack.pushAll(e.children)
+        case ref: BoundReference
+            if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null =>
+          val value = ctx.currentVars(ref.ordinal).value
+          val isNull = ctx.currentVars(ref.ordinal).isNull
+          if (CodegenContext.isJavaIdentifier(value)) {
+            argSet += ((ctx.javaType(ref.dataType), value))
+          }
+          if (CodegenContext.isJavaIdentifier(isNull)) {
+            argSet += (("boolean", isNull))
+          }
+        case _: BoundReference =>
+          argSet += (("InternalRow", ctx.INPUT_ROW))
+        case e =>
+          stack.pushAll(e.children)
+      }
+    }
+
+    argSet.toSet
+  }
+
+  // Splits the aggregation into small functions because HotSpot does not compile
+  // overly long functions.
+  private def splitAggregateExpressions(
+      ctx: CodegenContext,
+      aggExprs: Seq[Expression],
+      evalAndUpdateCodes: Seq[String],
+      subExprs: Map[Expression, SubExprEliminationState],
+      otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = {
+    aggExprs.zipWithIndex.map { case (aggExpr, i) =>
+      // The maximum number of parameters in non-static Java methods is 254, so this method gives
+      // up splitting the code if the number goes over the limit.
+      // You can find more information about the limit in the JVM specification:
+      // - The number of method parameters is limited to 255 by the definition of a method
--- End diff --
@a10y Thanks for the info. Very helpful.
---
Github user maropu commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r143897031
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -244,6 +246,89 @@ case class HashAggregateExec(
protected override val shouldStopRequired = false
+  // Extracts all the input variable references for a given `aggExpr`. This result will be used
+  // to split aggregation into small functions.
+  private def getInputVariableReferences(
+      ctx: CodegenContext,
+      aggExpr: Expression,
+      subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = {
+    // `argSet` collects all the pairs of variable names and their types; the first in the pair is
+    // a type name and the second is a variable name.
+    val argSet = mutable.Set[(String, String)]()
+    val stack = mutable.Stack[Expression](aggExpr)
+    while (stack.nonEmpty) {
+      stack.pop() match {
+        case e if subExprs.contains(e) =>
+          val exprCode = subExprs(e)
+          if (CodegenContext.isJavaIdentifier(exprCode.value)) {
+            argSet += ((ctx.javaType(e.dataType), exprCode.value))
+          }
+          if (CodegenContext.isJavaIdentifier(exprCode.isNull)) {
+            argSet += (("boolean", exprCode.isNull))
+          }
+          // Since the children possibly have common expressions, we push them here
+          stack.pushAll(e.children)
+        case ref: BoundReference
+            if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null =>
+          val value = ctx.currentVars(ref.ordinal).value
+          val isNull = ctx.currentVars(ref.ordinal).isNull
+          if (CodegenContext.isJavaIdentifier(value)) {
+            argSet += ((ctx.javaType(ref.dataType), value))
+          }
+          if (CodegenContext.isJavaIdentifier(isNull)) {
+            argSet += (("boolean", isNull))
+          }
+        case _: BoundReference =>
+          argSet += (("InternalRow", ctx.INPUT_ROW))
+        case e =>
+          stack.pushAll(e.children)
+      }
+    }
+
+    argSet.toSet
+  }
+
+  // Splits the aggregation into small functions because HotSpot does not compile
+  // overly long functions.
+  private def splitAggregateExpressions(
+      ctx: CodegenContext,
+      aggExprs: Seq[Expression],
+      evalAndUpdateCodes: Seq[String],
+      subExprs: Map[Expression, SubExprEliminationState],
+      otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = {
+    aggExprs.zipWithIndex.map { case (aggExpr, i) =>
+      // The maximum number of parameters in non-static Java methods is 254, so this method gives
+      // up splitting the code if the number goes over the limit.
+      // You can find more information about the limit in the JVM specification:
+      // - The number of method parameters is limited to 255 by the definition of a method
--- End diff --
I just fixed the code to give up splitting if the length goes over 127 because,
IIUC, the aggregate functions currently implemented in Spark do not go over the
limit. I feel it is somewhat complicated to check the types there...
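The conservative guard described here can be sketched as follows. This is an illustrative helper with hypothetical names, not the PR's actual code; it assumes arguments arrive as (typeName, variableName) pairs, as collected by `getInputVariableReferences`:

```java
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

// Sketch of the guard (hypothetical names, not the PR's exact code):
// without inspecting parameter types, assume the worst case where every
// argument is a long/double taking two slots, so 127 parameters (plus one
// slot for `this`) always fits in the 255-slot budget.
class SplitGuard {
    static final int MAX_SAFE_PARAMS = 127;

    // args holds (typeName, variableName) pairs; returns empty when
    // splitting must be abandoned because the signature may not fit.
    static Optional<String> functionSignature(String name, List<String[]> args) {
        if (args.size() > MAX_SAFE_PARAMS) {
            return Optional.empty();
        }
        String params = args.stream()
                .map(a -> a[0] + " " + a[1])
                .collect(Collectors.joining(", "));
        return Optional.of("private void " + name + "(" + params + ")");
    }
}
```

The 127 cap is the worst case in which every parameter needs two slots; counting slots per type would allow more parameters but requires the type inspection discussed in the thread.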
---
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user kiszk commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r143895917
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -244,6 +246,89 @@ case class HashAggregateExec(
protected override val shouldStopRequired = false
--- End diff --
@a10y Good catch. You are right. We have 254 slots. Each `long` or `double`
takes two slots. We need to check the types of the parameters, too. cc: @maropu, @viirya
---
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user maropu commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r143876358
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -244,6 +246,89 @@ case class HashAggregateExec(
protected override val shouldStopRequired = false
--- End diff --
Aha, I'll recheck this. Thanks.
---
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user a10y commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r143836700
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -244,6 +246,89 @@ case class HashAggregateExec(
protected override val shouldStopRequired = false
--- End diff --
Unless you want to check the types of all the parameters, you might be better
off halving this to 127 parameters for the worst case. Though I'm not sure how
many generated methods this affects...
---
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user a10y commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r143836110
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -244,6 +246,89 @@ case class HashAggregateExec(
protected override val shouldStopRequired = false
--- End diff --
If you read the spec closely, it actually says that there are 255 slots,
where 1 slot is taken by **this** and 2 slots each are taken up by **long** and
**double** parameters.
https://docs.oracle.com/javase/specs/jvms/se8/html/jvms-4.html#jvms-4.3.3
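The slot arithmetic from JVMS §4.3.3 can be sketched like this (an illustrative helper with hypothetical names, not part of the PR):

```java
import java.util.List;

// Hypothetical helper (not from the PR): counts JVM parameter slots per
// JVMS §4.3.3, where long and double occupy two slots each and `this`
// occupies one more slot for instance methods.
class ParamSlots {
    static int slotsFor(String javaType) {
        return ("long".equals(javaType) || "double".equals(javaType)) ? 2 : 1;
    }

    // Total slots consumed by an instance method's parameter list.
    static int totalSlots(List<String> paramTypes) {
        int total = 1; // one slot for `this`
        for (String t : paramTypes) {
            total += slotsFor(t);
        }
        return total;
    }

    // A method descriptor is valid only while the total stays within 255.
    static boolean fitsInMethod(List<String> paramTypes) {
        return totalSlots(paramTypes) <= 255;
    }
}
```

Under this accounting, 254 single-slot parameters still fit an instance method, but 128 `long` parameters already exceed the budget, which is why halving the limit to 127 is always safe without type inspection.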
---
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user maropu commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r143359416
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -797,26 +904,44 @@ case class HashAggregateExec(
def updateRowInFastHashMap(isVectorized: Boolean): Option[String] = {
- ctx.INPUT_ROW = fastRowBuffer
+      // We need to copy the aggregation row buffer to a local row first because each aggregate
+      // function directly updates the buffer when it finishes.
+      val localRowBuffer = ctx.freshName("localFastRowBuffer")
+      val initLocalRowBuffer = s"InternalRow $localRowBuffer = $fastRowBuffer.copy();"
--- End diff --
I just passed the local variable as each function argument;
```
/* 329 */       // do aggregate
/* 330 */       // copy aggregation row buffer to the local
/* 331 */       InternalRow agg_localFastRowBuffer = agg_fastAggBuffer.copy();
/* 332 */       // common sub-expressions
/* 333 */       boolean agg_isNull27 = false;
/* 334 */       long agg_value30 = -1L;
/* 335 */       if (!false) {
/* 336 */         agg_value30 = (long) inputadapter_value;
/* 337 */       }
/* 338 */       // process aggregate functions to update aggregation buffer
/* 339 */       agg_doAggregateVal_add2(inputadapter_value, agg_value30, agg_fastAggBuffer, agg_localFastRowBuffer, agg_isNull27);
/* 340 */       agg_doAggregateVal_add3(inputadapter_value, agg_value30, agg_fastAggBuffer, agg_localFastRowBuffer, agg_isNull27);
/* 341 */       agg_doAggregateVal_if1(inputadapter_value, agg_value30, agg_fastAggBuffer, agg_localFastRowBuffer, agg_isNull27);
/* 342 */
```
Since each split function directly updates an input row, we need to copy it
to the local so that all the split functions can reference the old state.
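The need for the snapshot can be illustrated with a minimal sketch (plain Java, not the generated code; the names are hypothetical):

```java
// Minimal illustration (not the generated code): each split function
// mutates the shared buffer in place, so without a snapshot the second
// function would read the first function's update instead of the
// original state.
class SnapshotDemo {
    static final class Buffer {
        long sum;
        Buffer(long sum) { this.sum = sum; }
        Buffer copy() { return new Buffer(sum); }
    }

    // Both "split functions" update buf but read the pre-update state
    // from the snapshot, mirroring agg_localFastRowBuffer above.
    static void updateA(Buffer buf, Buffer old) { buf.sum += old.sum; }
    static void updateB(Buffer buf, Buffer old) { buf.sum += old.sum; }

    static long run() {
        Buffer buf = new Buffer(10L);
        Buffer snapshot = buf.copy(); // like agg_fastAggBuffer.copy()
        updateA(buf, snapshot);       // 10 + 10 = 20
        updateB(buf, snapshot);       // 20 + 10 = 30
        return buf.sum;
    }
}
```

If the update functions read from `buf` directly instead of the snapshot, the result would drift to 40, which is exactly the state leakage the copy avoids.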
---
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user viirya commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r143326742
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -797,26 +904,44 @@ case class HashAggregateExec(
def updateRowInFastHashMap(isVectorized: Boolean): Option[String] = {
- ctx.INPUT_ROW = fastRowBuffer
--- End diff --
Why do we need to copy the row buffer? You bind `updateExpr` to the locally
copied row buffer, but the evaluation happens in the split functions. Isn't it
possible that `updateExpr` can't find the local variable of the copied row
buffer inside those functions?
---
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user maropu commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r143218854
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
---
@@ -944,6 +945,24 @@ class CodegenContext {
}
}
+object CodegenContext {
+
+ private val javaKeywords = Set(
--- End diff --
cc: @rednaxelafx
---
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user kiszk commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r143216395
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
---
@@ -944,6 +945,24 @@ class CodegenContext {
}
}
+object CodegenContext {
+
+ private val javaKeywords = Set(
--- End diff --
Do we need to add `enum`?
---
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user kiszk commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r136900835
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -244,6 +246,89 @@ case class HashAggregateExec(
protected override val shouldStopRequired = false
+
+  // Splits the aggregation code into small functions because HotSpot does not compile
+  // functions that are too long.
+  private def splitAggregateExpressions(
+      ctx: CodegenContext,
+      aggExprs: Seq[Expression],
+      evalAndUpdateCodes: Seq[String],
+      subExprs: Map[Expression, SubExprEliminationState],
+      otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = {
+    aggExprs.zipWithIndex.map { case (aggExpr, i) =>
+      // The maximum number of parameters in Java methods is 255, so this method gives up
+      // splitting the code if the number goes over the limit.
+      // You can find more information about the limit in the JVM specification:
+      // - The number of method parameters is limited to 255 by the definition of a method
+      //   descriptor, where the limit includes one unit for `this` in the case of instance
+      //   or interface method invocations.
+      val args = (getInputVariableReferences(ctx, aggExpr, subExprs) ++ otherArgs).toSeq
+
+      // This is for testing/benchmarking only
+      val maxParamNumInJavaMethod =
+        sqlContext.getConf("spark.sql.codegen.aggregate.maxParamNumInJavaMethod", null) match {
+          case null | "" => 255
--- End diff --
If line 314 uses `<=`, this should be 254. In the previous commit, `<` is
used.
---
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user maropu commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r136518024
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -244,6 +246,92 @@ case class HashAggregateExec(
protected override val shouldStopRequired = false
+  // We assume a prefix has lower cases and a name has camel cases
+  private val variableName = "^[a-z]+_[a-zA-Z]+[0-9]*".r
--- End diff --
Good suggestion, and I'm also looking for a better one. I'll try to fix it.
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user maropu commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r136517567
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -244,6 +246,92 @@ case class HashAggregateExec(
protected override val shouldStopRequired = false
+  // We assume a prefix has lower cases and a name has camel cases
+  private val variableName = "^[a-z]+_[a-zA-Z]+[0-9]*".r
+
+  // Returns true if a given name id belongs to this `CodegenContext`
+  private def isVariable(nameId: String): Boolean = nameId match {
+    case variableName() => true
+    case _ => false
+  }
+
+  // Extracts all the outer references for a given `aggExpr`. This result will be used to split
+  // aggregation into small functions.
+  private def getOuterReferences(
--- End diff --
ok, I'll rename this.
---
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user rednaxelafx commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r136506452
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -244,6 +246,92 @@ case class HashAggregateExec(
protected override val shouldStopRequired = false
+  private def getOuterReferences(
+      ctx: CodegenContext,
+      aggExpr: Expression,
+      subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = {
+    val stack = mutable.Stack[Expression](aggExpr)
+    val argSet = mutable.Set[(String, String)]()
+    val addIfNotLiteral = (value: String, tpe: String) => {
--- End diff --
Hmm, just a cosmetic style comment: I would have declared `addIfNotLiteral`
with a `def` instead of making it a `scala.Function2[String, String, Unit]`.
BTW, can we add a comment to `val argSet` explaining what the two fields of the
`Tuple2[String, String]` mean? And then also make this `addIfNotLiteral`
function take the arguments in the same order as the tuple.
---
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user rednaxelafx commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r136506046
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -244,6 +246,92 @@ case class HashAggregateExec(
protected override val shouldStopRequired = false
+ // We assume a prefix has lower cases and a name has camel cases
+ private val variableName = "^[a-z]+_[a-zA-Z]+[0-9]*".r
--- End diff --
I know the regular expression is tempting, but there's actually a better
way to do this along the same idea, under the current framework.
I've got a piece of code sitting in my own workspace that checks for Java
identifiers:
```scala
object CodegenContext {
  private val javaKeywords = Set(
    "abstract", "assert", "boolean", "break", "byte", "case", "catch", "char", "class",
    "const", "continue", "default", "do", "double", "else", "extends", "false", "final",
    "finally", "float", "for", "goto", "if", "implements", "import", "instanceof", "int",
    "interface", "long", "native", "new", "null", "package", "private", "protected", "public",
    "return", "short", "static", "strictfp", "super", "switch", "synchronized", "this",
    "throw", "throws", "transient", "true", "try", "void", "volatile", "while"
  )

  def isJavaIdentifier(str: String): Boolean = str match {
    case null | "" => false
    case _ =>
      java.lang.Character.isJavaIdentifierStart(str.charAt(0)) &&
        (1 until str.length).forall(i => java.lang.Character.isJavaIdentifierPart(str.charAt(i))) &&
        !javaKeywords.contains(str)
  }
}
```
Feel free to use it here if you'd like. This is the way
`java.lang.Character.isJavaIdentifierStart()` and
`java.lang.Character.isJavaIdentifierPart()` are supposed to be used anyway,
nothing creative.
If you want to use it in a `case` like the way you're using the regular
expression, just wrap the util above into an `unapply()`. But I'd say simply
making `def isVariable(nameId: String) = CodegenContext.isJavaIdentifier(nameId)`
is clean enough.
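To see why both halves of the check matter, here is a standalone Java rendering of the same logic (illustrative only; the Scala version above is the one proposed for Spark). Codegen "values" can be literals such as `-1L` or `false` rather than variable names, so the character test and the keyword filter are both needed:

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

// Standalone rendering (illustrative, not the Spark code): decides whether
// a codegen "value" string is a plain Java identifier, filtering out
// literals like "-1L" and reserved words like "false" or "while".
class JavaIdentifiers {
    static final Set<String> JAVA_KEYWORDS = new HashSet<>(Arrays.asList(
        "abstract", "assert", "boolean", "break", "byte", "case", "catch", "char", "class",
        "const", "continue", "default", "do", "double", "else", "extends", "false", "final",
        "finally", "float", "for", "goto", "if", "implements", "import", "instanceof", "int",
        "interface", "long", "native", "new", "null", "package", "private", "protected",
        "public", "return", "short", "static", "strictfp", "super", "switch", "synchronized",
        "this", "throw", "throws", "transient", "true", "try", "void", "volatile", "while"));

    static boolean isJavaIdentifier(String str) {
        if (str == null || str.isEmpty()) {
            return false;
        }
        if (!Character.isJavaIdentifierStart(str.charAt(0))) {
            return false;
        }
        for (int i = 1; i < str.length(); i++) {
            if (!Character.isJavaIdentifierPart(str.charAt(i))) {
                return false;
            }
        }
        return !JAVA_KEYWORDS.contains(str);
    }
}
```

A name like `agg_value30` passes, while `-1L` fails on its first character and `while` fails the keyword filter, so neither gets collected as a function argument.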
---
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user viirya commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r136500779
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -244,6 +246,92 @@ case class HashAggregateExec(
protected override val shouldStopRequired = false
+  // Extracts all the outer references for a given `aggExpr`. This result will be used to split
+  // aggregation into small functions.
+  private def getOuterReferences(
--- End diff --
`OuterReference` actually has a special meaning in correlated subqueries. This
name can be confusing.
---
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user maropu commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r136490017
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -244,6 +246,92 @@ case class HashAggregateExec(
protected override val shouldStopRequired = false
+    val stack = mutable.Stack[Expression](aggExpr)
+    val argSet = mutable.Set[(String, String)]()
+    val addIfNotLiteral = (value: String, tpe: String) => {
+      if (isVariable(value)) {
+        argSet += ((tpe, value))
+      }
+    }
+    while (stack.nonEmpty) {
+      stack.pop() match {
+        case e if subExprs.contains(e) =>
+          val exprCode = subExprs(e)
+          addIfNotLiteral(exprCode.value, ctx.javaType(e.dataType))
+          addIfNotLiteral(exprCode.isNull, "boolean")
+          // Since the children may contain common expressions, we push them here
+          stack.pushAll(e.children)
+        case ref: BoundReference
+            if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null =>
+          val argVal = ctx.currentVars(ref.ordinal).value
+          addIfNotLiteral(argVal, ctx.javaType(ref.dataType))
+          addIfNotLiteral(ctx.currentVars(ref.ordinal).isNull, "boolean")
+        case _: BoundReference =>
+          argSet += (("InternalRow", ctx.INPUT_ROW))
+        case e =>
+          stack.pushAll(e.children)
+      }
+    }
+
+    argSet.toSet
+  }
+
+  // Splits the aggregation code into small functions because HotSpot does not compile
+  // functions that are too long.
+  private def splitAggregateExpressions(
+      ctx: CodegenContext,
+      aggExprs: Seq[Expression],
+      evalAndUpdateCodes: Seq[String],
+      subExprs: Map[Expression, SubExprEliminationState],
+      otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = {
+    aggExprs.zipWithIndex.map { case (aggExpr, i) =>
+      // The maximum number of parameters in Java methods is 255, so this method gives up
+      // splitting the code if the number goes over the limit.
+      // You can find more information about the limit in the JVM specification:
+      // - The number of method parameters is limited to 255 by the definition of a method
+      //   descriptor, where the limit includes one unit for `this` in the case of instance
+      //   or interface method invocations.
+      val args = (getOuterReferences(ctx, aggExpr, subExprs) ++ otherArgs).toSeq
+
+      // This is for testing/benchmarking only
+      val maxParamNumInJavaMethod =
+        sqlContext.getConf("spark.sql.codegen.aggregate.maxParamNumInJavaMethod", null) match {
--- End diff --
This is a test-only option, so I think we need not check that.
---
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user kiszk commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r136455832
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -244,6 +246,92 @@ case class HashAggregateExec(
protected override val shouldStopRequired = false
+  // We assume a prefix has lower cases and a name has camel cases
+  private val variableName = "^[a-z]+_[a-zA-Z]+[0-9]*".r
+
+  // Returns true if a given name id belongs to this `CodegenContext`
+  private def isVariable(nameId: String): Boolean = nameId match {
+    case variableName() => true
+    case _ => false
+  }
+
+  // Extracts all the outer references for a given `aggExpr`. This result will be used to split
+  // aggregation into small functions.
+  private def getOuterReferences(
+      ctx: CodegenContext,
+      aggExpr: Expression,
+      subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = {
+    val stack = mutable.Stack[Expression](aggExpr)
+    val argSet = mutable.Set[(String, String)]()
+    val addIfNotLiteral = (value: String, tpe: String) => {
+      if (isVariable(value)) {
+        argSet += ((tpe, value))
+      }
+    }
+    while (stack.nonEmpty) {
+      stack.pop() match {
+        case e if subExprs.contains(e) =>
+          val exprCode = subExprs(e)
+          addIfNotLiteral(exprCode.value, ctx.javaType(e.dataType))
+          addIfNotLiteral(exprCode.isNull, "boolean")
+          // Since the children possibly have common expressions, we push them here
+          stack.pushAll(e.children)
+        case ref: BoundReference
+            if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null =>
+          val argVal = ctx.currentVars(ref.ordinal).value
+          addIfNotLiteral(argVal, ctx.javaType(ref.dataType))
+          addIfNotLiteral(ctx.currentVars(ref.ordinal).isNull, "boolean")
+        case _: BoundReference =>
+          argSet += (("InternalRow", ctx.INPUT_ROW))
+        case e =>
+          stack.pushAll(e.children)
+      }
+    }
+
+    argSet.toSet
+  }
+
+  // Splits the aggregation into small functions because HotSpot does not compile
+  // overly long functions.
+  private def splitAggregateExpressions(
+      ctx: CodegenContext,
+      aggExprs: Seq[Expression],
+      evalAndUpdateCodes: Seq[String],
+      subExprs: Map[Expression, SubExprEliminationState],
+      otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = {
+    aggExprs.zipWithIndex.map { case (aggExpr, i) =>
+      // The maximum number of parameters in Java methods is 255, so this method gives up
+      // splitting the code if the number goes over the limit.
+      // You can find more information about the limit in the JVM specification:
+      //   - The number of method parameters is limited to 255 by the definition of a method
+      //     descriptor, where the limit includes one unit for `this` in the case of instance
+      //     or interface method invocations.
+      val args = (getOuterReferences(ctx, aggExpr, subExprs) ++ otherArgs).toSeq
+
+      // This is for testing/benchmarking only
+      val maxParamNumInJavaMethod =
+        sqlContext.getConf("spark.sql.codegen.aggregate.maxParamNumInJavaMethod", null) match {
--- End diff --
Can we add a check for the case where a user specifies a value greater than 255?
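A sketch of the kind of guard being requested here; the helper name is illustrative, not the actual patch, and the 255 cap is the JVM's method-descriptor limit:

```scala
// Illustrative guard for a max-parameter config: the JVM caps a method
// descriptor at 255 units, one of which is consumed by `this` for
// instance methods, so user-supplied values above 255 can never be honored.
val jvmMaxParamUnits = 255

def resolveMaxParamNum(confValue: String): Int = confValue match {
  case null | "" => jvmMaxParamUnits
  case s =>
    val n = s.toInt
    require(0 < n && n <= jvmMaxParamUnits,
      s"maxParamNumInJavaMethod must be in (0, $jvmMaxParamUnits], got $n")
    n
}
```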
---
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user maropu commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r136207629
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -244,6 +246,92 @@ case class HashAggregateExec(
protected override val shouldStopRequired = false
+  // We assume a prefix has lower cases and a name has camel cases
+  private val variableName = "^[a-z]+_[a-zA-Z]+[0-9]*".r
+
+  // Returns true if a given name id belongs to this `CodegenContext`
+  private def isVariable(nameId: String): Boolean = nameId match {
+    case variableName() => true
+    case _ => false
+  }
+
+  // Extracts all the outer references for a given `aggExpr`. This result will be used to split
+  // aggregation into small functions.
+  private def getOuterReferences(
+      ctx: CodegenContext,
+      aggExpr: Expression,
+      subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = {
+    val stack = mutable.Stack[Expression](aggExpr)
+    val argSet = mutable.Set[(String, String)]()
+    val addIfNotLiteral = (value: String, tpe: String) => {
+      if (isVariable(value)) {
+        argSet += ((tpe, value))
+      }
+    }
+    while (stack.nonEmpty) {
+      stack.pop() match {
+        case e if subExprs.contains(e) =>
+          val exprCode = subExprs(e)
+          addIfNotLiteral(exprCode.value, ctx.javaType(e.dataType))
+          addIfNotLiteral(exprCode.isNull, "boolean")
+          // Since the children possibly have common expressions, we push them here
+          stack.pushAll(e.children)
+        case ref: BoundReference
+            if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null =>
+          val argVal = ctx.currentVars(ref.ordinal).value
+          addIfNotLiteral(argVal, ctx.javaType(ref.dataType))
+          addIfNotLiteral(ctx.currentVars(ref.ordinal).isNull, "boolean")
+        case _: BoundReference =>
+          argSet += (("InternalRow", ctx.INPUT_ROW))
+        case e =>
+          stack.pushAll(e.children)
+      }
+    }
+
+    argSet.toSet
+  }
+
+  // Splits the aggregation into small functions because HotSpot does not compile
+  // overly long functions.
+  private def splitAggregateExpressions(
+      ctx: CodegenContext,
+      aggExprs: Seq[Expression],
+      evalAndUpdateCodes: Seq[String],
+      subExprs: Map[Expression, SubExprEliminationState],
+      otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = {
+    aggExprs.zipWithIndex.map { case (aggExpr, i) =>
+      // The maximum number of parameters in Java methods is 255, so this method gives up
+      // splitting the code if the number goes over the limit.
+      // You can find more information about the limit in the JVM specification:
+      //   - The number of method parameters is limited to 255 by the definition of a method
+      //     descriptor, where the limit includes one unit for `this` in the case of instance
+      //     or interface method invocations.
+      val args = (getOuterReferences(ctx, aggExpr, subExprs) ++ otherArgs).toSeq
+
+      // This is for testing/benchmarking only
+      val maxParamNumInJavaMethod =
+        sqlContext.getConf("spark.sql.codegen.aggregate.maxParamNumInJavaMethod", null) match {
+          case null | "" => 256
--- End diff --
oh, good catch! I'll fix
---
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user kiszk commented on a diff in the pull request:
https://github.com/apache/spark/pull/19082#discussion_r136090141
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
@@ -244,6 +246,92 @@ case class HashAggregateExec(
protected override val shouldStopRequired = false
+  // We assume a prefix has lower cases and a name has camel cases
+  private val variableName = "^[a-z]+_[a-zA-Z]+[0-9]*".r
+
+  // Returns true if a given name id belongs to this `CodegenContext`
+  private def isVariable(nameId: String): Boolean = nameId match {
+    case variableName() => true
+    case _ => false
+  }
+
+  // Extracts all the outer references for a given `aggExpr`. This result will be used to split
+  // aggregation into small functions.
+  private def getOuterReferences(
+      ctx: CodegenContext,
+      aggExpr: Expression,
+      subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = {
+    val stack = mutable.Stack[Expression](aggExpr)
+    val argSet = mutable.Set[(String, String)]()
+    val addIfNotLiteral = (value: String, tpe: String) => {
+      if (isVariable(value)) {
+        argSet += ((tpe, value))
+      }
+    }
+    while (stack.nonEmpty) {
+      stack.pop() match {
+        case e if subExprs.contains(e) =>
+          val exprCode = subExprs(e)
+          addIfNotLiteral(exprCode.value, ctx.javaType(e.dataType))
+          addIfNotLiteral(exprCode.isNull, "boolean")
+          // Since the children possibly have common expressions, we push them here
+          stack.pushAll(e.children)
+        case ref: BoundReference
+            if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null =>
+          val argVal = ctx.currentVars(ref.ordinal).value
+          addIfNotLiteral(argVal, ctx.javaType(ref.dataType))
+          addIfNotLiteral(ctx.currentVars(ref.ordinal).isNull, "boolean")
+        case _: BoundReference =>
+          argSet += (("InternalRow", ctx.INPUT_ROW))
+        case e =>
+          stack.pushAll(e.children)
+      }
+    }
+
+    argSet.toSet
+  }
+
+  // Splits the aggregation into small functions because HotSpot does not compile
+  // overly long functions.
+  private def splitAggregateExpressions(
+      ctx: CodegenContext,
+      aggExprs: Seq[Expression],
+      evalAndUpdateCodes: Seq[String],
+      subExprs: Map[Expression, SubExprEliminationState],
+      otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = {
+    aggExprs.zipWithIndex.map { case (aggExpr, i) =>
+      // The maximum number of parameters in Java methods is 255, so this method gives up
+      // splitting the code if the number goes over the limit.
+      // You can find more information about the limit in the JVM specification:
+      //   - The number of method parameters is limited to 255 by the definition of a method
+      //     descriptor, where the limit includes one unit for `this` in the case of instance
+      //     or interface method invocations.
+      val args = (getOuterReferences(ctx, aggExpr, subExprs) ++ otherArgs).toSeq
+
+      // This is for testing/benchmarking only
+      val maxParamNumInJavaMethod =
+        sqlContext.getConf("spark.sql.codegen.aggregate.maxParamNumInJavaMethod", null) match {
+          case null | "" => 256
--- End diff --
Since `$doAggVal` is a [non-static method](https://stackoverflow.com/questions/30581531/maximum-number-of-parameters-in-java-method-declaration), this number should be `255`.
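To make the arithmetic behind this concrete: the JVM spec counts descriptor *units* rather than parameters, where an instance method spends one unit on `this` and each `long` or `double` parameter costs two units. A small illustrative helper (not part of the patch):

```scala
// Count JVM method-descriptor units for a parameter list. The JVM spec
// limits the total to 255: `this` takes one unit for instance methods,
// and long/double parameters take two units each.
def descriptorUnits(paramTypes: Seq[String], isStatic: Boolean): Int = {
  val thisUnit = if (isStatic) 0 else 1
  thisUnit + paramTypes.map {
    case "long" | "double" => 2
    case _                 => 1
  }.sum
}

def fitsInOneJavaMethod(paramTypes: Seq[String], isStatic: Boolean): Boolean =
  descriptorUnits(paramTypes, isStatic) <= 255
```

So an instance method like `agg_doAggregateWithoutKey`'s split helpers can take at most 254 one-unit parameters, and fewer still once `long`/`double` arguments are involved.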
---
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
GitHub user maropu opened a pull request:
https://github.com/apache/spark/pull/19082
[SPARK-21870][SQL] Split aggregation code into small functions for the HotSpot
## What changes were proposed in this pull request?
This pr proposes to split the aggregation code into pieces in `HashAggregateExec` for the JVM HotSpot.
In #18810, we saw a performance regression when HotSpot did not compile overly long functions (the limit is 8000 in bytecode size). I checked and found that the code generated by `HashAggregateExec` frequently goes over this limit, for example:
```
spark.range(1000).selectExpr("id % 1024 AS a", "id AS b").write.saveAsTable("t")
sql("SELECT a, KURTOSIS(b) FROM t GROUP BY a")
```
This query goes over the limit; its actual bytecode size is `12356`.
This pr splits the aggregation code into small separate functions; in a simple example:
```
sql("SELECT SUM(a), AVG(a) FROM VALUES(1) t(a)").debugCodegen
```
- generated code with this pr:
```
/* 083 */   private void agg_doAggregateWithoutKey() throws java.io.IOException {
/* 084 */     // initialize aggregation buffer
/* 085 */     final long agg_value = -1L;
/* 086 */     agg_bufIsNull = true;
/* 087 */     agg_bufValue = agg_value;
/* 088 */     boolean agg_isNull1 = false;
/* 089 */     double agg_value1 = -1.0;
/* 090 */     if (!false) {
/* 091 */       agg_value1 = (double) 0;
/* 092 */     }
/* 093 */     agg_bufIsNull1 = agg_isNull1;
/* 094 */     agg_bufValue1 = agg_value1;
/* 095 */     agg_bufIsNull2 = false;
/* 096 */     agg_bufValue2 = 0L;
/* 097 */
/* 098 */     while (inputadapter_input.hasNext() && !stopEarly()) {
/* 099 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 100 */       boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 101 */       long inputadapter_value = inputadapter_isNull ? -1L : (inputadapter_row.getLong(0));
/* 102 */       boolean inputadapter_isNull1 = inputadapter_row.isNullAt(1);
/* 103 */       double inputadapter_value1 = inputadapter_isNull1 ? -1.0 : (inputadapter_row.getDouble(1));
/* 104 */       boolean inputadapter_isNull2 = inputadapter_row.isNullAt(2);
/* 105 */       long inputadapter_value2 = inputadapter_isNull2 ? -1L : (inputadapter_row.getLong(2));
/* 106 */
/* 107 */       // do aggregate
/* 108 */       // copy aggregation buffer to the local
/* 109 */       boolean agg_localBufIsNull = agg_bufIsNull;
/* 110 */       long agg_localBufValue = agg_bufValue;
/* 111 */       boolean agg_localBufIsNull1 = agg_bufIsNull1;
/* 112 */       double agg_localBufValue1 = agg_bufValue1;
/* 113 */       boolean agg_localBufIsNull2 = agg_bufIsNull2;
/* 114 */       long agg_localBufValue2 = agg_bufValue2;
/* 115 */       // common sub-expressions
/* 116 */
/* 117 */       // process aggregate functions to update aggregation buffer
/* 118 */       agg_doAggregateVal_coalesce(agg_localBufIsNull, agg_localBufValue, inputadapter_value, inputadapter_isNull);
/* 119 */       agg_doAggregateVal_add(agg_localBufValue1, inputadapter_isNull1, inputadapter_value1, agg_localBufIsNull1);
/* 120 */       agg_doAggregateVal_add1(inputadapter_isNull2, inputadapter_value2, agg_localBufIsNull2, agg_localBufValue2);
/* 121 */       if (shouldStop()) return;
/* 122 */     }
```
- generated code in the current master
```
/* 083 */   private void agg_doAggregateWithoutKey() throws java.io.IOException {
/* 084 */     // initialize aggregation buffer
/* 085 */     final long agg_value = -1L;
/* 086 */     agg_bufIsNull = true;
/* 087 */     agg_bufValue = agg_value;
/* 088 */     boolean agg_isNull1 = false;
/* 089 */     double agg_value1 = -1.0;
/* 090 */     if (!false) {
/* 091 */       agg_value1 = (double) 0;
/* 092 */     }
/* 093 */     agg_bufIsNull1 = agg_isNull1;
/* 094 */     agg_bufValue1 = agg_value1;
/* 095 */     agg_bufIsNull2 = false;
/* 096 */     agg_bufValue2 = 0L;
/* 097 */
/* 098 */     while (inputadapter_input.hasNext() && !stopEarly()) {
/* 099 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 100 */       boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 101 */       long inputadapter_value = inputadapter_isNull ? -1L : (inputadapter_row.getLong(0));
/* 102 */       boolean inputadapter_isNull1 = inputadapter_row.isNullAt(1);
/* 103 */       double inputadapter_value1 = inputadapter_isNull1 ? -1.0 : (inputadapter_row.getDouble(1));
/* 104 */       boolean inputadapter_isNull2 = inputadapter_row.isNullAt(2);
/* 105 */       long inputadapter_value2 = inputadapter_isNull2 ? -1L : (inputadapter_row.getLong(2));
/* 106 */
/* 107 */       // do aggreg
