Repository: spark
Updated Branches:
  refs/heads/branch-1.6 ad2ebe4db -> 285792b6c


[SPARK-11725][SQL] correctly handle null inputs for UDF

If a user declares primitive parameters in a UDF, there is no way for them to 
null-check those inputs inside the function, so we assume primitive inputs are 
null-propagatable in this case and return null if any such input is null.

Author: Wenchen Fan <[email protected]>

Closes #9770 from cloud-fan/udf.

(cherry picked from commit 33b837333435ceb0c04d1f361a5383c4fe6a5a75)
Signed-off-by: Michael Armbrust <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/285792b6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/285792b6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/285792b6

Branch: refs/heads/branch-1.6
Commit: 285792b6cb39361f308561f844f164c4fe919f2b
Parents: ad2ebe4
Author: Wenchen Fan <[email protected]>
Authored: Wed Nov 18 10:23:12 2015 -0800
Committer: Michael Armbrust <[email protected]>
Committed: Wed Nov 18 10:23:27 2015 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/ScalaReflection.scala    |  9 ++++
 .../spark/sql/catalyst/analysis/Analyzer.scala  | 32 +++++++++++++-
 .../sql/catalyst/expressions/ScalaUDF.scala     |  6 +++
 .../sql/catalyst/ScalaReflectionSuite.scala     | 17 ++++++++
 .../sql/catalyst/analysis/AnalysisSuite.scala   | 44 ++++++++++++++++++++
 .../org/apache/spark/sql/DataFrameSuite.scala   | 14 +++++++
 6 files changed, 121 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/285792b6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 0b3dd35..38828e5 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -719,6 +719,15 @@ trait ScalaReflection {
     }
   }
 
+  /**
+   * Returns classes of input parameters of scala function object.
+   */
+  def getParameterTypes(func: AnyRef): Seq[Class[_]] = {
+    val methods = func.getClass.getMethods.filter(m => m.getName == "apply" && 
!m.isBridge)
+    assert(methods.length == 1)
+    methods.head.getParameterTypes
+  }
+
   def typeOfObject: PartialFunction[Any, DataType] = {
     // The data type can be determined without ambiguity.
     case obj: Boolean => BooleanType

http://git-wip-us.apache.org/repos/asf/spark/blob/285792b6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 2f4670b..f00c451 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.trees.TreeNodeRef
-import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
+import org.apache.spark.sql.catalyst.{ScalaReflection, SimpleCatalystConf, 
CatalystConf}
 import org.apache.spark.sql.types._
 
 /**
@@ -85,6 +85,8 @@ class Analyzer(
       extendedResolutionRules : _*),
     Batch("Nondeterministic", Once,
       PullOutNondeterministic),
+    Batch("UDF", Once,
+      HandleNullInputsForUDF),
     Batch("Cleanup", fixedPoint,
       CleanupAliases)
   )
@@ -1063,6 +1065,34 @@ class Analyzer(
         Project(p.output, newPlan.withNewChildren(newChild :: Nil))
     }
   }
+
+  /**
+   * Correctly handle null primitive inputs for UDF by adding extra [[If]] 
expression to do the
+   * null check.  When user defines a UDF with primitive parameters, there is 
no way to tell if the
+   * primitive parameter is null or not, so here we assume the primitive input 
is null-propagatable
+   * and we should return null if the input is null.
+   */
+  object HandleNullInputsForUDF extends Rule[LogicalPlan] {
+    override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators 
{
+      case p if !p.resolved => p // Skip unresolved nodes.
+
+      case plan => plan transformExpressionsUp {
+
+        case udf @ ScalaUDF(func, _, inputs, _) =>
+          val parameterTypes = ScalaReflection.getParameterTypes(func)
+          assert(parameterTypes.length == inputs.length)
+
+          val inputsNullCheck = parameterTypes.zip(inputs)
+            // TODO: skip null handling for not-nullable primitive inputs 
after we can completely
+            // trust the `nullable` information.
+            // .filter { case (cls, expr) => cls.isPrimitive && expr.nullable }
+            .filter { case (cls, _) => cls.isPrimitive }
+            .map { case (_, expr) => IsNull(expr) }
+            .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2))
+          inputsNullCheck.map(If(_, Literal.create(null, udf.dataType), 
udf)).getOrElse(udf)
+      }
+    }
+  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/285792b6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index 3388cc2..03b8922 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -24,7 +24,13 @@ import org.apache.spark.sql.types.DataType
 
 /**
  * User-defined function.
+ * @param function  The user defined scala function to run.
+ *                  Note that if you use primitive parameters, you are not 
able to check if it is
+ *                  null or not, and the UDF will return null for you if the 
primitive input is
+ *                  null. Use boxed type or [[Option]] if you wanna do the 
null-handling yourself.
  * @param dataType  Return type of function.
+ * @param children  The input expressions of this UDF.
+ * @param inputTypes  The expected input types of this UDF.
  */
 case class ScalaUDF(
     function: AnyRef,

http://git-wip-us.apache.org/repos/asf/spark/blob/285792b6/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 3b848cf..4ea410d 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -280,4 +280,21 @@ class ScalaReflectionSuite extends SparkFunSuite {
         assert(s.fields.map(_.dataType) === Seq(IntegerType, StringType, 
DoubleType))
     }
   }
+
+  test("get parameter type from a function object") {
+    val primitiveFunc = (i: Int, j: Long) => "x"
+    val primitiveTypes = getParameterTypes(primitiveFunc)
+    assert(primitiveTypes.forall(_.isPrimitive))
+    assert(primitiveTypes === Seq(classOf[Int], classOf[Long]))
+
+    val boxedFunc = (i: java.lang.Integer, j: java.lang.Long) => "x"
+    val boxedTypes = getParameterTypes(boxedFunc)
+    assert(boxedTypes.forall(!_.isPrimitive))
+    assert(boxedTypes === Seq(classOf[java.lang.Integer], 
classOf[java.lang.Long]))
+
+    val anyFunc = (i: Any, j: AnyRef) => "x"
+    val anyTypes = getParameterTypes(anyFunc)
+    assert(anyTypes.forall(!_.isPrimitive))
+    assert(anyTypes === Seq(classOf[java.lang.Object], 
classOf[java.lang.Object]))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/285792b6/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 65f09b4..08586a9 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -174,4 +174,48 @@ class AnalysisSuite extends AnalysisTest {
     )
     assertAnalysisError(plan, Seq("data type mismatch: Arguments must be same 
type"))
   }
+
+  test("SPARK-11725: correctly handle null inputs for ScalaUDF") {
+    val string = testRelation2.output(0)
+    val double = testRelation2.output(2)
+    val short = testRelation2.output(4)
+    val nullResult = Literal.create(null, StringType)
+
+    def checkUDF(udf: Expression, transformed: Expression): Unit = {
+      checkAnalysis(
+        Project(Alias(udf, "")() :: Nil, testRelation2),
+        Project(Alias(transformed, "")() :: Nil, testRelation2)
+      )
+    }
+
+    // non-primitive parameters do not need special null handling
+    val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil)
+    val expected1 = udf1
+    checkUDF(udf1, expected1)
+
+    // only primitive parameter needs special null handling
+    val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: 
double :: Nil)
+    val expected2 = If(IsNull(double), nullResult, udf2)
+    checkUDF(udf2, expected2)
+
+    // special null handling should apply to all primitive parameters
+    val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: 
double :: Nil)
+    val expected3 = If(
+      IsNull(short) || IsNull(double),
+      nullResult,
+      udf3)
+    checkUDF(udf3, expected3)
+
+    // we can skip special null handling for primitive parameters that are not 
nullable
+    // TODO: this is disabled for now as we can not completely trust 
`nullable`.
+    val udf4 = ScalaUDF(
+      (s: Short, d: Double) => "x",
+      StringType,
+      short :: double.withNullability(false) :: Nil)
+    val expected4 = If(
+      IsNull(short),
+      nullResult,
+      udf4)
+    // checkUDF(udf4, expected4)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/285792b6/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 35cdab5..5a7f246 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1115,4 +1115,18 @@ class DataFrameSuite extends QueryTest with 
SharedSQLContext {
     checkAnswer(df.select(df("*")), Row(1, "a"))
     checkAnswer(df.withColumnRenamed("d^'a.", "a"), Row(1, "a"))
   }
+
+  test("SPARK-11725: correctly handle null inputs for ScalaUDF") {
+    val df = Seq(
+      new java.lang.Integer(22) -> "John",
+      null.asInstanceOf[java.lang.Integer] -> "Lucy").toDF("age", "name")
+
+    val boxedUDF = udf[java.lang.Integer, java.lang.Integer] {
+      (i: java.lang.Integer) => if (i == null) null else i * 2
+    }
+    checkAnswer(df.select(boxedUDF($"age")), Row(44) :: Row(null) :: Nil)
+
+    val primitiveUDF = udf((i: Int) => i * 2)
+    checkAnswer(df.select(primitiveUDF($"age")), Row(44) :: Row(null) :: Nil)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to