Repository: spark Updated Branches: refs/heads/master 6f41c593b -> 813c0f945
[SPARK-22704][SQL] Least and Greatest use less global variables ## What changes were proposed in this pull request? This PR accomplishes the following two items. 1. Reduce # of global variables from two to one 2. Make lifetime of global variable local within an operation Item 1. reduces # of constant pool entries in a Java class. Item 2. ensures that a variable is not passed to arguments in a method split by `CodegenContext.splitExpressions()`, which is addressed by #19865. ## How was this patch tested? Added new test into `ArithmeticExpressionSuite` Author: Kazuaki Ishizaki <ishiz...@jp.ibm.com> Closes #19899 from kiszk/SPARK-22704. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/813c0f94 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/813c0f94 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/813c0f94 Branch: refs/heads/master Commit: 813c0f945d7f03800975eaed26b86a1f30e513c9 Parents: 6f41c59 Author: Kazuaki Ishizaki <ishiz...@jp.ibm.com> Authored: Thu Dec 7 00:45:51 2017 +0800 Committer: Wenchen Fan <wenc...@databricks.com> Committed: Thu Dec 7 00:45:51 2017 +0800 ---------------------------------------------------------------------- .../sql/catalyst/expressions/arithmetic.scala | 94 +++++++++++++------- .../expressions/ArithmeticExpressionSuite.scala | 11 +++ 2 files changed, 73 insertions(+), 32 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/813c0f94/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 739bd13..1893eec 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -602,23 +602,38 @@ case class Least(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) - ctx.addMutableState(ctx.javaType(dataType), ev.value) - def updateEval(eval: ExprCode): String = { + val tmpIsNull = ctx.freshName("leastTmpIsNull") + ctx.addMutableState(ctx.JAVA_BOOLEAN, tmpIsNull) + val evals = evalChildren.map(eval => s""" - ${eval.code} - if (!${eval.isNull} && (${ev.isNull} || - ${ctx.genGreater(dataType, ev.value, eval.value)})) { - ${ev.isNull} = false; - ${ev.value} = ${eval.value}; - } - """ - } - val codes = ctx.splitExpressionsWithCurrentInputs(evalChildren.map(updateEval)) - ev.copy(code = s""" - ${ev.isNull} = true; - ${ev.value} = ${ctx.defaultValue(dataType)}; - $codes""") + |${eval.code} + |if (!${eval.isNull} && ($tmpIsNull || + | ${ctx.genGreater(dataType, ev.value, eval.value)})) { + | $tmpIsNull = false; + | ${ev.value} = ${eval.value}; + |} + """.stripMargin + ) + + val resultType = ctx.javaType(dataType) + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = evals, + funcName = "least", + extraArguments = Seq(resultType -> ev.value), + returnType = resultType, + makeSplitFunction = body => + s""" + |$body + |return ${ev.value}; + """.stripMargin, + foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) + ev.copy(code = + s""" + |$tmpIsNull = true; + |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |$codes + |final boolean ${ev.isNull} = $tmpIsNull; + """.stripMargin) } } @@ -668,22 +683,37 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { 
val evalChildren = children.map(_.genCode(ctx)) - ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) - ctx.addMutableState(ctx.javaType(dataType), ev.value) - def updateEval(eval: ExprCode): String = { + val tmpIsNull = ctx.freshName("greatestTmpIsNull") + ctx.addMutableState(ctx.JAVA_BOOLEAN, tmpIsNull) + val evals = evalChildren.map(eval => s""" - ${eval.code} - if (!${eval.isNull} && (${ev.isNull} || - ${ctx.genGreater(dataType, eval.value, ev.value)})) { - ${ev.isNull} = false; - ${ev.value} = ${eval.value}; - } - """ - } - val codes = ctx.splitExpressionsWithCurrentInputs(evalChildren.map(updateEval)) - ev.copy(code = s""" - ${ev.isNull} = true; - ${ev.value} = ${ctx.defaultValue(dataType)}; - $codes""") + |${eval.code} + |if (!${eval.isNull} && ($tmpIsNull || + | ${ctx.genGreater(dataType, eval.value, ev.value)})) { + | $tmpIsNull = false; + | ${ev.value} = ${eval.value}; + |} + """.stripMargin + ) + + val resultType = ctx.javaType(dataType) + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = evals, + funcName = "greatest", + extraArguments = Seq(resultType -> ev.value), + returnType = resultType, + makeSplitFunction = body => + s""" + |$body + |return ${ev.value}; + """.stripMargin, + foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) + ev.copy(code = + s""" + |$tmpIsNull = true; + |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |$codes + |final boolean ${ev.isNull} = $tmpIsNull; + """.stripMargin) } } http://git-wip-us.apache.org/repos/asf/spark/blob/813c0f94/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index fb759eb..be638d8 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.types._ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -343,4 +344,14 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Least(inputsExpr), "s" * 1, EmptyRow) checkEvaluation(Greatest(inputsExpr), "s" * N, EmptyRow) } + + test("SPARK-22704: Least and greatest use less global variables") { + val ctx1 = new CodegenContext() + Least(Seq(Literal(1), Literal(1))).genCode(ctx1) + assert(ctx1.mutableStates.size == 1) + + val ctx2 = new CodegenContext() + Greatest(Seq(Literal(1), Literal(1))).genCode(ctx2) + assert(ctx2.mutableStates.size == 1) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org