maryannxue commented on a change in pull request #23712: [SPARK-26798][SQL] HandleNullInputsForUDF should trust nullability
URL: https://github.com/apache/spark/pull/23712#discussion_r255381444
 
 

 ##########
 File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 ##########
 @@ -2147,27 +2147,36 @@ class Analyzer(
 
       case p => p transformExpressionsUp {
 
-        case udf @ ScalaUDF(_, _, inputs, inputsNullSafe, _, _, _, _)
-            if inputsNullSafe.contains(false) =>
+        case udf @ ScalaUDF(_, _, inputs, inputPrimitives, _, _, _, _)
+            if inputPrimitives.contains(true) =>
           // Otherwise, add special handling of null for fields that can't 
accept null.
           // The result of operations like this, when passed null, is 
generally to return null.
-          assert(inputsNullSafe.length == inputs.length)
-
-          // TODO: skip null handling for not-nullable primitive inputs after 
we can completely
-          // trust the `nullable` information.
-          val inputsNullCheck = inputsNullSafe.zip(inputs)
-            .filter { case (nullSafe, _) => !nullSafe }
-            .map { case (_, expr) => IsNull(expr) }
-            .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2))
-          // Once we add an `If` check above the udf, it is safe to mark those 
checked inputs
-          // as null-safe (i.e., set `inputsNullSafe` all `true`), because the 
null-returning
-          // branch of `If` will be called if any of these checked inputs is 
null. Thus we can
-          // prevent this rule from being applied repeatedly.
-          val newInputsNullSafe = inputsNullSafe.map(_ => true)
-          inputsNullCheck
-            .map(If(_, Literal.create(null, udf.dataType),
-              udf.copy(inputsNullSafe = newInputsNullSafe)))
-            .getOrElse(udf)
+          assert(inputPrimitives.length == inputs.length)
+
+          val inputPrimitivesPair = inputPrimitives.zip(inputs)
+          val inputNullCheck = inputPrimitivesPair.collect {
+            case (isPrimitive, input) if isPrimitive && input.nullable =>
+              IsNull(input)
+          }.reduceLeftOption[Expression](Or)
+
+          if (inputNullCheck.isDefined) {
+            // Once we add an `If` check above the udf, it is safe to mark 
those checked inputs
+            // as null-safe (i.e., wrap with `KnownNotNull`), because the 
null-returning
+            // branch of `If` will be called if any of these checked inputs is 
null. Thus we can
+            // prevent this rule from being applied repeatedly.
+            val newInputs = inputPrimitivesPair.map {
+              case (isPrimitive, input) =>
+                if (isPrimitive && input.nullable) {
 
 Review comment:
   This is just to confirm. And I think this is the right way to go.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to