[ https://issues.apache.org/jira/browse/SPARK-23833?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]
Hyukjin Kwon resolved SPARK-23833.
----------------------------------
    Resolution: Incomplete

> Incorrect primitive type check for input arguments of udf
> ---------------------------------------------------------
>
>                 Key: SPARK-23833
>                 URL: https://issues.apache.org/jira/browse/SPARK-23833
>             Project: Spark
>          Issue Type: Bug
>          Components: Optimizer
>    Affects Versions: 2.2.0, 2.3.0
>            Reporter: Valentin Nikotin
>            Priority: Major
>              Labels: bulk-closed
>
> The documented behavior for Scala UDFs with primitive-type arguments is:
> {quote}Note that if you use primitive parameters, you are not able to check
> if it is null or not, and the UDF will return null for you if the primitive
> input is null.
> {quote}
> The original issue is SPARK-11725.
> Corresponding PR:
> [PR|https://github.com/apache/spark/pull/9770/commits/a8a30674ce531c9cd10107200a3f72f9539cd8f6]
> The problem is that {{ScalaReflection.getParameterTypes}} doesn't work
> correctly due to type erasure.
> The check for whether a type is primitive should instead be based on the
> typeTag, something like this:
> {code:scala}
> typeTag[T].tpe.typeSymbol.asClass.isPrimitive
> {code}
>
> The problem appears when higher-order functions are involved:
> {code:scala}
> val f = (x: Long) => x
> def identity[T, U](f: T => U): T => U = (t: T) => f(t)
> val udf0 = udf(f)
> val udf1 = udf(identity(f))
> val getNull = udf(() => null.asInstanceOf[java.lang.Long])
> spark.range(5).toDF().
>   withColumn("udf0", udf0(getNull())).
>   withColumn("udf1", udf1(getNull())).
>   show()
> spark.range(5).toDF().
>   withColumn("udf0", udf0(getNull())).
>   withColumn("udf1", udf1(getNull())).
>   explain()
> {code}
> Test execution in the Spark 2.2 spark-shell:
> {code:scala}
> scala> val f = (x: Long) => x
> f: Long => Long = <function1>
> scala> def identity[T, U](f: T => U): T => U = (t: T) => f(t)
> identity: [T, U](f: T => U)T => U
> scala> val udf0 = udf(f)
> udf0: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,LongType,Some(List(LongType)))
> scala> val udf1 = udf(identity(f))
> udf1: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,LongType,Some(List(LongType)))
> scala> val getNull = udf(() => null.asInstanceOf[java.lang.Long])
> getNull: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function0>,LongType,Some(List()))
> scala> spark.range(5).toDF().
>      |   withColumn("udf0", udf0(getNull())).
>      |   withColumn("udf1", udf1(getNull())).
>      |   show()
> +---+----+----+
> | id|udf0|udf1|
> +---+----+----+
> |  0|null|   0|
> |  1|null|   0|
> |  2|null|   0|
> |  3|null|   0|
> |  4|null|   0|
> +---+----+----+
> scala> spark.range(5).toDF().
>      |   withColumn("udf0", udf0(getNull())).
>      |   withColumn("udf1", udf1(getNull())).
>      |   explain()
> == Physical Plan ==
> *Project [id#19L, if (isnull(UDF())) null else UDF(UDF()) AS udf0#24L, UDF(UDF()) AS udf1#28L]
> +- *Range (0, 5, step=1, splits=6)
> {code}
> Note that the plan guards udf0 with an isnull check but not udf1, which is
> why udf1 returns 0 instead of null.
>
> The typeTag information about the input parameters is available in the udf
> function but is only used to derive the schema. It should be passed to
> ScalaUDF too so that it can be used later:
> {code:scala}
> def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = {
>   val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType ::
>     ScalaReflection.schemaFor(typeTag[A2]).dataType :: Nil).toOption
>   UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes)
> }
> {code}
>
> Here is the current behavior versus the desired behavior:
> {code:scala}
> scala> import org.apache.spark.sql.catalyst.ScalaReflection
> import org.apache.spark.sql.catalyst.ScalaReflection
> scala> ScalaReflection.getParameterTypes(identity(f))
> res2: Seq[Class[_]] = WrappedArray(class java.lang.Object)
> scala> ScalaReflection.getParameterTypes(identity(f)).map(_.isPrimitive)
> res7: Seq[Boolean] = ArrayBuffer(false)
> {code}
> versus
> {code:scala}
> scala> import scala.reflect.runtime.universe.{typeTag, TypeTag}
> import scala.reflect.runtime.universe.{typeTag, TypeTag}
> scala> def myGetParameterTypes[T : TypeTag, U](func: T => U) = {
>      |   typeTag[T].tpe.typeSymbol.asClass
>      | }
> myGetParameterTypes: [T, U](func: T => U)(implicit evidence$1: reflect.runtime.universe.TypeTag[T])reflect.runtime.universe.ClassSymbol
> scala> myGetParameterTypes(f)
> res3: reflect.runtime.universe.ClassSymbol = class Long
> scala> myGetParameterTypes(f).isPrimitive
> res4: Boolean = true
> {code}
> For this particular case there is a workaround using {{@specialized(Long)}}:
> {code:scala}
> scala> def identity2[@specialized(Long) T, U](f: T => U): T => U = (t: T) => f(t)
> identity2: [T, U](f: T => U)T => U
> scala> ScalaReflection.getParameterTypes(identity2(f))
> res10: Seq[Class[_]] = WrappedArray(long)
> {code}
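
The TypeTag-based check proposed in the report can be tried outside Spark. The following is a minimal sketch, assuming Scala 2 with scala-reflect on the classpath. The names PrimitiveCheck, isPrimitiveArg and wrap are hypothetical and are not part of Spark. wrap stands in for the generic identity helper from the report. The sketch only illustrates that the static type seen through the TypeTag survives the generic wrapper. It does not show how Spark itself performs or should perform the check.

```scala
import scala.reflect.runtime.universe.{typeTag, TypeTag}

object PrimitiveCheck {
  // Hypothetical helper (not a Spark API): decides from the static type T,
  // captured by the TypeTag at the call site, whether the argument type
  // is one of the Scala primitive classes (Long, Int, Double, ...).
  def isPrimitiveArg[T: TypeTag, U](func: T => U): Boolean =
    typeTag[T].tpe.typeSymbol.asClass.isPrimitive

  // Generic wrapper comparable to `identity` in the report: at runtime the
  // returned function only exposes an erased parameter type.
  def wrap[T, U](f: T => U): T => U = (t: T) => f(t)

  def main(args: Array[String]): Unit = {
    val f = (x: Long) => x
    val boxed = (x: java.lang.Long) => x

    println(isPrimitiveArg(f))       // true: T is scala.Long
    println(isPrimitiveArg(wrap(f))) // true: T is still inferred as scala.Long
    println(isPrimitiveArg(boxed))   // false: java.lang.Long is a reference type
  }
}
```

Because the TypeTag is resolved by the compiler where the function is passed in, the result does not depend on which apply methods the function object exposes at runtime, which is what an approach based on Java reflection has to rely on.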