Repository: spark
Updated Branches:
  refs/heads/master 813c0f945 -> e98f9647f
[SPARK-22695][SQL] ScalaUDF should not use global variables ## What changes were proposed in this pull request? ScalaUDF is using global variables which are not needed. This can generate some unneeded entries in the constant pool. The PR replaces the unneeded global variables with local variables. ## How was this patch tested? added UT Author: Marco Gaido <mga...@hortonworks.com> Author: Marco Gaido <marcogaid...@gmail.com> Closes #19900 from mgaido91/SPARK-22695. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e98f9647 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e98f9647 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e98f9647 Branch: refs/heads/master Commit: e98f9647f44d1071a6b070db070841b8cda6bd7a Parents: 813c0f9 Author: Marco Gaido <mga...@hortonworks.com> Authored: Thu Dec 7 00:50:49 2017 +0800 Committer: Wenchen Fan <wenc...@databricks.com> Committed: Thu Dec 7 00:50:49 2017 +0800 ---------------------------------------------------------------------- .../sql/catalyst/expressions/ScalaUDF.scala | 88 ++++++++++---------- .../catalyst/expressions/ScalaUDFSuite.scala | 6 ++ 2 files changed, 51 insertions(+), 43 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/e98f9647/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 1798530..4d26d98 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -982,35 +982,28 @@ case class ScalaUDF( // scalastyle:on line.size.limit - 
// Generate codes used to convert the arguments to Scala type for user-defined functions - private[this] def genCodeForConverter(ctx: CodegenContext, index: Int): String = { - val converterClassName = classOf[Any => Any].getName - val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$" - val expressionClassName = classOf[Expression].getName - val scalaUDFClassName = classOf[ScalaUDF].getName + private val converterClassName = classOf[Any => Any].getName + private val scalaUDFClassName = classOf[ScalaUDF].getName + private val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$" + // Generate codes used to convert the arguments to Scala type for user-defined functions + private[this] def genCodeForConverter(ctx: CodegenContext, index: Int): (String, String) = { val converterTerm = ctx.freshName("converter") val expressionIdx = ctx.references.size - 1 - ctx.addMutableState(converterClassName, converterTerm, - s"$converterTerm = ($converterClassName)$typeConvertersClassName" + - s".createToScalaConverter(((${expressionClassName})((($scalaUDFClassName)" + - s"references[$expressionIdx]).getChildren().apply($index))).dataType());") - converterTerm + (converterTerm, + s"$converterClassName $converterTerm = ($converterClassName)$typeConvertersClassName" + + s".createToScalaConverter(((Expression)((($scalaUDFClassName)" + + s"references[$expressionIdx]).getChildren().apply($index))).dataType());") } override def doGenCode( ctx: CodegenContext, ev: ExprCode): ExprCode = { + val scalaUDF = ctx.freshName("scalaUDF") + val scalaUDFRef = ctx.addReferenceMinorObj(this, scalaUDFClassName) - val scalaUDF = ctx.addReferenceObj("scalaUDF", this) - val converterClassName = classOf[Any => Any].getName - val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$" - - // Generate codes used to convert the returned value of user-defined functions to Catalyst type + // Object to convert the returned value of 
user-defined functions to Catalyst type val catalystConverterTerm = ctx.freshName("catalystConverter") - ctx.addMutableState(converterClassName, catalystConverterTerm, - s"$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" + - s".createToCatalystConverter($scalaUDF.dataType());") val resultTerm = ctx.freshName("result") @@ -1022,8 +1015,6 @@ case class ScalaUDF( val funcClassName = s"scala.Function${children.size}" val funcTerm = ctx.freshName("udf") - ctx.addMutableState(funcClassName, funcTerm, - s"$funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();") // codegen for children expressions val evals = children.map(_.genCode(ctx)) @@ -1033,34 +1024,45 @@ case class ScalaUDF( // such as IntegerType, its javaType is `int` and the returned type of user-defined // function is Object. Trying to convert an Object to `int` will cause casting exception. val evalCode = evals.map(_.code).mkString - val (converters, funcArguments) = converterTerms.zipWithIndex.map { case (converter, i) => - val eval = evals(i) - val argTerm = ctx.freshName("arg") - val convert = s"Object $argTerm = ${eval.isNull} ? null : $converter.apply(${eval.value});" - (convert, argTerm) + val (converters, funcArguments) = converterTerms.zipWithIndex.map { + case ((convName, convInit), i) => + val eval = evals(i) + val argTerm = ctx.freshName("arg") + val convert = + s""" + |$convInit + |Object $argTerm = ${eval.isNull} ? 
null : $convName.apply(${eval.value}); + """.stripMargin + (convert, argTerm) }.unzip val getFuncResult = s"$funcTerm.apply(${funcArguments.mkString(", ")})" val callFunc = s""" - ${ctx.boxedType(dataType)} $resultTerm = null; - try { - $resultTerm = (${ctx.boxedType(dataType)})$catalystConverterTerm.apply($getFuncResult); - } catch (Exception e) { - throw new org.apache.spark.SparkException($scalaUDF.udfErrorMessage(), e); - } - """ + |${ctx.boxedType(dataType)} $resultTerm = null; + |$scalaUDFClassName $scalaUDF = $scalaUDFRef; + |try { + | $funcClassName $funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc(); + | $converterClassName $catalystConverterTerm = ($converterClassName) + | $typeConvertersClassName.createToCatalystConverter($scalaUDF.dataType()); + | $resultTerm = (${ctx.boxedType(dataType)})$catalystConverterTerm.apply($getFuncResult); + |} catch (Exception e) { + | throw new org.apache.spark.SparkException($scalaUDF.udfErrorMessage(), e); + |} + """.stripMargin - ev.copy(code = s""" - $evalCode - ${converters.mkString("\n")} - $callFunc - - boolean ${ev.isNull} = $resultTerm == null; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.value} = $resultTerm; - }""") + ev.copy(code = + s""" + |$evalCode + |${converters.mkString("\n")} + |$callFunc + | + |boolean ${ev.isNull} = $resultTerm == null; + |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |if (!${ev.isNull}) { + | ${ev.value} = $resultTerm; + |} + """.stripMargin) } private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType) http://git-wip-us.apache.org/repos/asf/spark/blob/e98f9647/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala index 13bd363..70dea4b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Locale import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.types.{IntegerType, StringType} class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -47,4 +48,9 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { assert(e2.getMessage.contains("Failed to execute user defined function")) } + test("SPARK-22695: ScalaUDF should not use global variables") { + val ctx = new CodegenContext + ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil).genCode(ctx) + assert(ctx.mutableStates.isEmpty) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org