This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 217f1748224 [SPARK-37829][SQL] Dataframe.joinWith outer-join should 
return a null value for unmatched row
217f1748224 is described below

commit 217f1748224b6ade306eb5f0782e0af085378c55
Author: --global <xuqiang...@gmail.com>
AuthorDate: Wed Apr 19 22:05:04 2023 +0800

    [SPARK-37829][SQL] Dataframe.joinWith outer-join should return a null value 
for unmatched row
    
    ### What changes were proposed in this pull request?
    
    When doing an outer join with joinWith on DataFrames, unmatched rows return 
Row objects with null fields instead of a single null value. This is not the 
expected behavior; it is a regression introduced in [this 
commit](https://github.com/apache/spark/commit/cd92f25be5a221e0d4618925f7bc9dfd3bb8cb59).
    This pull request fixes the regression. Note that it is not a full 
rollback of that commit: it does not add back the "schema" variable.
    
    ```
    case class ClassData(a: String, b: Int)
    val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDF
    val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDF
    
    left.joinWith(right, left("b") === right("b"), "left_outer").collect
    ```
    
    ```
    Wrong results (current behavior):    Array(([a,1],[null,null]), 
([b,2],[x,2]))
    Correct results:                     Array(([a,1],null), ([b,2],[x,2]))
    ```
    
    ### Why are the changes needed?
    
    We need to address the regression mentioned above. It results in unexpected 
behavior changes in the Dataframe joinWith API between versions 2.4.8 and 
3.0.0+. This could potentially cause data correctness issues for users who 
expect the old behavior when using Spark 3.0.0+.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added a unit test (the same test as in the previous [closed pull 
request](https://github.com/apache/spark/pull/35140); credit to Clément de Groc).
    Ran the sql-core and sql-catalyst submodules locally with ./build/mvn clean 
package -pl sql/core,sql/catalyst
    
    Closes #40755 from kings129/encoder_bug_fix.
    
    Authored-by: --global <xuqiang...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 74ce620901a958a1ddd76360e2faed7d3a111d4e)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/catalyst/encoders/ExpressionEncoder.scala  | 19 ++++++---
 .../scala/org/apache/spark/sql/DatasetSuite.scala  | 45 ++++++++++++++++++++++
 2 files changed, 58 insertions(+), 6 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index faa165c298d..8f7583c48fc 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -97,22 +97,29 @@ object ExpressionEncoder {
     }
     val newSerializer = CreateStruct(serializers)
 
+    def nullSafe(input: Expression, result: Expression): Expression = {
+      If(IsNull(input), Literal.create(null, result.dataType), result)
+    }
+
     val newDeserializerInput = GetColumnByOrdinal(0, newSerializer.dataType)
-    val deserializers = encoders.zipWithIndex.map { case (enc, index) =>
+    val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) 
=>
       val getColExprs = enc.objDeserializer.collect { case c: 
GetColumnByOrdinal => c }.distinct
       assert(getColExprs.size == 1, "object deserializer should have only one 
" +
         s"`GetColumnByOrdinal`, but there are ${getColExprs.size}")
 
       val input = GetStructField(newDeserializerInput, index)
-      enc.objDeserializer.transformUp {
+      val childDeserializer = enc.objDeserializer.transformUp {
         case GetColumnByOrdinal(0, _) => input
       }
-    }
-    val newDeserializer = NewInstance(cls, deserializers, ObjectType(cls), 
propagateNull = false)
 
-    def nullSafe(input: Expression, result: Expression): Expression = {
-      If(IsNull(input), Literal.create(null, result.dataType), result)
+      if (enc.objSerializer.nullable) {
+        nullSafe(input, childDeserializer)
+      } else {
+        childDeserializer
+      }
     }
+    val newDeserializer =
+      NewInstance(cls, childrenDeserializers, ObjectType(cls), propagateNull = 
false)
 
     new ExpressionEncoder[Any](
       nullSafe(newSerializerInput, newSerializer),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 86e640a4fa8..f8f6845afca 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.internal.config.MAX_RESULT_SIZE
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, 
ScroogeLikeExample}
 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
+import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
 import org.apache.spark.sql.catalyst.util.sideBySide
 import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SQLExecution}
@@ -2416,6 +2417,50 @@ class DatasetSuite extends QueryTest
       assert(parquetFiles.size === 10)
     }
   }
+
+  test("SPARK-37829: DataFrame outer join") {
+    // Same as "SPARK-15441: Dataset outer join" but using DataFrames instead 
of Datasets
+    val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDF().as("left")
+    val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDF().as("right")
+    val joined = left.joinWith(right, $"left.b" === $"right.b", "left")
+
+    val leftFieldSchema = StructType(
+      Seq(
+        StructField("a", StringType),
+        StructField("b", IntegerType, nullable = false)
+      )
+    )
+    val rightFieldSchema = StructType(
+      Seq(
+        StructField("a", StringType),
+        StructField("b", IntegerType, nullable = false)
+      )
+    )
+    val expectedSchema = StructType(
+      Seq(
+        StructField(
+          "_1",
+          leftFieldSchema,
+          nullable = false
+        ),
+        // This is a left join, so the right output is nullable:
+        StructField(
+          "_2",
+          rightFieldSchema
+        )
+      )
+    )
+    assert(joined.schema === expectedSchema)
+
+    val result = joined.collect().toSet
+    val expected = Set(
+      new GenericRowWithSchema(Array("a", 1), leftFieldSchema) ->
+        null,
+      new GenericRowWithSchema(Array("b", 2), leftFieldSchema) ->
+        new GenericRowWithSchema(Array("x", 2), rightFieldSchema)
+    )
+    assert(result == expected)
+  }
 }
 
 class DatasetLargeResultCollectingSuite extends QueryTest


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to