This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8b8ea60bd4f2 [SPARK-47927][SQL] Fix nullability attribute in UDF decoder
8b8ea60bd4f2 is described below

commit 8b8ea60bd4f22ea5763a77bac2d51f25d2479be9
Author: Emil Ejbyfeldt <eejbyfe...@liveintent.com>
AuthorDate: Sun Apr 28 13:46:03 2024 +0800

    [SPARK-47927][SQL] Fix nullability attribute in UDF decoder
    
    ### What changes were proposed in this pull request?
    
    This PR fixes a correctness issue by moving the batch that resolves UDF
    decoders so that it runs after the `UpdateNullability` batch. The UDF
    decoder is therefore now derived from attributes whose nullability has
    already been updated (e.g. columns that became nullable through an outer
    join).
    
    I believe the issue has existed since
    https://github.com/apache/spark/pull/28645, when UDF support for case class
    arguments was added, so it should be present in all currently supported
    versions.
    
    ### Why are the changes needed?
    
    Currently the following code
    ```
    scala> val ds1 = Seq(1).toDS()
         | val ds2 = Seq[Int]().toDS()
         | val f = udf[Tuple1[Option[Int]],Tuple1[Option[Int]]](identity)
         | ds1.join(ds2, ds1("value") === ds2("value"), "left_outer").select(f(struct(ds2("value")))).collect()
    val ds1: org.apache.spark.sql.Dataset[Int] = [value: int]
    val ds2: org.apache.spark.sql.Dataset[Int] = [value: int]
    val f: org.apache.spark.sql.expressions.UserDefinedFunction = SparkUserDefinedFunction($Lambda$2481/0x00007f7f50961f08@6b1a2c9f,StructType(StructField(_1,IntegerType,true)),List(Some(class[_1[0]: int])),Some(class[_1[0]: int]),None,true,true)
    val res0: Array[org.apache.spark.sql.Row] = Array([[0]])
    ```
    results in a row containing `0`, which is incorrect: the value should be
    `null`. Removing the UDF call
    ```
    scala> ds1.join(ds2, ds1("value") === ds2("value"), "left_outer").select(struct(ds2("value"))).collect()
    val res1: Array[org.apache.spark.sql.Row] = Array([[null]])
    ```
    gives the correct value.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, it fixes a correctness issue when using Scala UDFs.
    
    ### How was this patch tested?
    
    Existing and new unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #46156 from eejbyfeldt/SPARK-47927.
    
    Authored-by: Emil Ejbyfeldt <eejbyfe...@liveintent.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala     |  4 ++--
 sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala   | 11 +++++++++++
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e666200a78d4..4b753e1f28e5 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -339,11 +339,11 @@ class Analyzer(override val catalogManager: 
CatalogManager) extends RuleExecutor
       new ResolveHints.RemoveAllHints),
     Batch("Nondeterministic", Once,
       PullOutNondeterministic),
+    Batch("UpdateNullability", Once,
+      UpdateAttributeNullability),
     Batch("UDF", Once,
       HandleNullInputsForUDF,
       ResolveEncodersInUDF),
-    Batch("UpdateNullability", Once,
-      UpdateAttributeNullability),
     Batch("Subquery", Once,
       UpdateOuterReferences),
     Batch("Cleanup", fixedPoint,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 87ca3a07c4d5..fe47d6c68555 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -1183,4 +1183,15 @@ class UDFSuite extends QueryTest with SharedSparkSession 
{
       df10.select(zip_with(col("array1"), col("array2"), (b1, b2) => 
reverseThenConcat2(b1, b2)))
     checkAnswer(test10, Row(Array(Row("cbaihg"), Row("fedlkj"))) :: Nil)
   }
+
+  test("SPARK-47927: Correctly pass null values derived from join to UDF") {
+    val f = udf[Tuple1[Option[Int]], Tuple1[Option[Int]]](identity)
+    val ds1 = Seq(1).toDS()
+    val ds2 = Seq[Int]().toDS()
+
+    checkAnswer(
+      ds1.join(ds2, ds1("value") === ds2("value"), "left_outer")
+        .select(f(struct(ds2("value").as("_1")))),
+      Row(Row(null)))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to