[ 
https://issues.apache.org/jira/browse/SPARK-23833?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel
 ]

Hyukjin Kwon resolved SPARK-23833.
----------------------------------
    Resolution: Incomplete

> Incorrect primitive type check for input arguments of udf
> ---------------------------------------------------------
>
>                 Key: SPARK-23833
>                 URL: https://issues.apache.org/jira/browse/SPARK-23833
>             Project: Spark
>          Issue Type: Bug
>          Components: Optimizer
>    Affects Versions: 2.2.0, 2.3.0
>            Reporter: Valentin Nikotin
>            Priority: Major
>              Labels: bulk-closed
>
> The documented behavior for Scala UDFs with primitive-type arguments is:
> {quote}Note that if you use primitive parameters, you are not able to check 
> if it is null or not, and the UDF will return null for you if the primitive 
> input is null.
> {quote}
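> For illustration, a minimal spark-shell snippet of that documented behavior
> (the column name {{n}} is my own, not from the docs):
> {code:scala}
> // A UDF over a primitive Long never sees a null input: Spark
> // short-circuits and returns null for that row.
> val inc = udf((x: Long) => x + 1)
> spark.range(1).
>   withColumn("n", lit(null).cast("long")).
>   select(inc(col("n"))).
>   show()
> // the result column is null, not 1
> {code}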
> The original issue is SPARK-11725, and the corresponding fix is this
> [PR|https://github.com/apache/spark/pull/9770/commits/a8a30674ce531c9cd10107200a3f72f9539cd8f6].
> The problem is that {{ScalaReflection.getParameterTypes}} does not work 
> correctly because of type erasure. The "is this type primitive?" check 
> should instead be based on a {{TypeTag}}, something like this:
> {code:java}
> typeTag[T].tpe.typeSymbol.asClass.isPrimitive
> {code}
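> For example, a small standalone helper built on that check (the helper 
> name is mine, not Spark's):
> {code:scala}
> import scala.reflect.runtime.universe._
>
> // The TypeTag carries the actual type argument, so the check is
> // immune to erasure, unlike Class-based reflection.
> def isPrimitive[T: TypeTag]: Boolean =
>   typeTag[T].tpe.typeSymbol.asClass.isPrimitive
>
> isPrimitive[Long]           // true
> isPrimitive[java.lang.Long] // false
> {code}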
>  
> The problem shows up when the function is passed through a higher-order 
> function:
> {code:java}
> val f = (x: Long) => x
> // wrapping f in a generic function erases its parameter type to Object
> def identity[T, U](f: T => U): T => U = (t: T) => f(t)
> val udf0 = udf(f)           // parameter still seen as primitive Long
> val udf1 = udf(identity(f)) // parameter seen as erased Object
> val getNull = udf(() => null.asInstanceOf[java.lang.Long])
> spark.range(5).toDF().
>   withColumn("udf0", udf0(getNull())).
>   withColumn("udf1", udf1(getNull())).
>   show()
> spark.range(5).toDF().
>   withColumn("udf0", udf0(getNull())).
>   withColumn("udf1", udf1(getNull())).
>   explain()
> {code}
> Test execution in the Spark 2.2 spark-shell:
> {code:java}
> scala> val f = (x: Long) => x
> f: Long => Long = <function1>
> scala> def identity[T, U](f: T => U): T => U = (t: T) => f(t)
> identity: [T, U](f: T => U)T => U
> scala> val udf0 = udf(f)
> udf0: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,LongType,Some(List(LongType)))
> scala> val udf1 = udf(identity(f))
> udf1: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,LongType,Some(List(LongType)))
> scala> val getNull = udf(() => null.asInstanceOf[java.lang.Long])
> getNull: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function0>,LongType,Some(List()))
> scala> spark.range(5).toDF().
>      |   withColumn("udf0", udf0(getNull())).
>      |   withColumn("udf1", udf1(getNull())).
>      |   show()
> +---+----+----+
> | id|udf0|udf1|
> +---+----+----+
> |  0|null|   0|
> |  1|null|   0|
> |  2|null|   0|
> |  3|null|   0|
> |  4|null|   0|
> +---+----+----+
> scala> spark.range(5).toDF().
>      |   withColumn("udf0", udf0(getNull())).
>      |   withColumn("udf1", udf1(getNull())).
>      |   explain()
> == Physical Plan ==
> *Project [id#19L, if (isnull(UDF())) null else UDF(UDF()) AS udf0#24L, UDF(UDF()) AS udf1#28L]
> +- *Range (0, 5, step=1, splits=6)
> {code}
> Note that the plan wraps {{udf0}} in a null check while {{udf1}} is 
> invoked directly, which is why {{udf1}} returns 0 instead of null.
>
> The {{TypeTag}} information about the input parameters is available in the 
> {{udf}} function but is only used to derive the schema; it should be passed 
> to {{ScalaUDF}} as well so that it can be used later:
> {code:java}
> def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = {
>   val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType ::
>     ScalaReflection.schemaFor(typeTag[A2]).dataType :: Nil).toOption
>   UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes)
> }
> {code}
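> As a rough sketch (the helper is hypothetical, not Spark code), the same 
> TypeTags could also yield per-argument primitive flags for {{ScalaUDF}} to 
> base its null-check decision on, instead of the erased classes:
> {code:scala}
> import scala.reflect.runtime.universe._
>
> // Hypothetical: derive "is this input primitive?" flags from the
> // TypeTags available at udf() construction time.
> def inputPrimitiveFlags[A1: TypeTag, A2: TypeTag]: Seq[Boolean] =
>   Seq(typeTag[A1].tpe, typeTag[A2].tpe).
>     map(_.typeSymbol.asClass.isPrimitive)
> {code}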
>  
> Here is the current behavior versus the desired one:
> {code:java}
> scala> import org.apache.spark.sql.catalyst.ScalaReflection
> import org.apache.spark.sql.catalyst.ScalaReflection
> scala> ScalaReflection.getParameterTypes(identity(f))
> res2: Seq[Class[_]] = WrappedArray(class java.lang.Object)
> scala> ScalaReflection.getParameterTypes(identity(f)).map(_.isPrimitive)
> res7: Seq[Boolean] = ArrayBuffer(false)
> {code}
> versus
> {code:java}
> scala> import scala.reflect.runtime.universe.{typeTag, TypeTag}
> import scala.reflect.runtime.universe.{typeTag, TypeTag}
> scala> def myGetParameterTypes[T : TypeTag, U](func: T => U) = {
>      |   typeTag[T].tpe.typeSymbol.asClass
>      | }
> myGetParameterTypes: [T, U](func: T => U)(implicit evidence$1: reflect.runtime.universe.TypeTag[T])reflect.runtime.universe.ClassSymbol
> scala> myGetParameterTypes(f)
> res3: reflect.runtime.universe.ClassSymbol = class Long
> scala> myGetParameterTypes(f).isPrimitive
> res4: Boolean = true
> {code}
> For this particular case there is a workaround using {{@specialized(Long)}}:
> {code:scala}
> scala> def identity2[@specialized(Long) T, U](f: T => U): T => U = (t: T) => f(t)
> identity2: [T, U](f: T => U)T => U
> scala> ScalaReflection.getParameterTypes(identity2(f))
> res10: Seq[Class[_]] = WrappedArray(long)
> {code}
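> Continuing the same spark-shell session, applying the workaround to the 
> original repro (my own follow-up, not part of the report) should restore 
> the null check:
> {code:scala}
> // With the specialized wrapper the parameter type survives as
> // primitive long, so the null guard should be generated as for udf0.
> val udf2 = udf(identity2(f))
> spark.range(5).toDF().
>   withColumn("udf2", udf2(getNull())).
>   show()
> // the udf2 column should be null on every row, matching udf0
> {code}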


