This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch branch-3.4 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push: new 217f1748224 [SPARK-37829][SQL] Dataframe.joinWith outer-join should return a null value for unmatched row 217f1748224 is described below commit 217f1748224b6ade306eb5f0782e0af085378c55 Author: --global <xuqiang...@gmail.com> AuthorDate: Wed Apr 19 22:05:04 2023 +0800 [SPARK-37829][SQL] Dataframe.joinWith outer-join should return a null value for unmatched row ### What changes were proposed in this pull request? When doing an outer join with joinWith on DataFrames, unmatched rows return Row objects with null fields instead of a single null value. This is not the expected behavior, and it is a regression introduced in [this commit](https://github.com/apache/spark/commit/cd92f25be5a221e0d4618925f7bc9dfd3bb8cb59). This pull request aims to fix the regression. Note that this is not a full rollback of that commit: the "schema" variable is intentionally not added back. ``` case class ClassData(a: String, b: Int) val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDF val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDF left.joinWith(right, left("b") === right("b"), "left_outer").collect ``` ``` Wrong results (current behavior): Array(([a,1],[null,null]), ([b,2],[x,2])) Correct results: Array(([a,1],null), ([b,2],[x,2])) ``` ### Why are the changes needed? We need to address the regression mentioned above. It causes unexpected behavior changes in the DataFrame joinWith API between versions 2.4.8 and 3.0.0+. This could cause data correctness issues for users who expect the old behavior when using Spark 3.0.0+. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added a unit test (the same test as in a previous [closed pull request](https://github.com/apache/spark/pull/35140); credit to Clément de Groc). Ran the sql-core and sql-catalyst submodule tests locally with ./build/mvn clean package -pl sql/core,sql/catalyst Closes #40755 from kings129/encoder_bug_fix. 
Authored-by: --global <xuqiang...@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> (cherry picked from commit 74ce620901a958a1ddd76360e2faed7d3a111d4e) Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../sql/catalyst/encoders/ExpressionEncoder.scala | 19 ++++++--- .../scala/org/apache/spark/sql/DatasetSuite.scala | 45 ++++++++++++++++++++++ 2 files changed, 58 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index faa165c298d..8f7583c48fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -97,22 +97,29 @@ object ExpressionEncoder { } val newSerializer = CreateStruct(serializers) + def nullSafe(input: Expression, result: Expression): Expression = { + If(IsNull(input), Literal.create(null, result.dataType), result) + } + val newDeserializerInput = GetColumnByOrdinal(0, newSerializer.dataType) - val deserializers = encoders.zipWithIndex.map { case (enc, index) => + val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) => val getColExprs = enc.objDeserializer.collect { case c: GetColumnByOrdinal => c }.distinct assert(getColExprs.size == 1, "object deserializer should have only one " + s"`GetColumnByOrdinal`, but there are ${getColExprs.size}") val input = GetStructField(newDeserializerInput, index) - enc.objDeserializer.transformUp { + val childDeserializer = enc.objDeserializer.transformUp { case GetColumnByOrdinal(0, _) => input } - } - val newDeserializer = NewInstance(cls, deserializers, ObjectType(cls), propagateNull = false) - def nullSafe(input: Expression, result: Expression): Expression = { - If(IsNull(input), Literal.create(null, result.dataType), result) + if (enc.objSerializer.nullable) { + 
nullSafe(input, childDeserializer) + } else { + childDeserializer + } } + val newDeserializer = + NewInstance(cls, childrenDeserializers, ObjectType(cls), propagateNull = false) new ExpressionEncoder[Any]( nullSafe(newSerializerInput, newSerializer), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 86e640a4fa8..f8f6845afca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.internal.config.MAX_RESULT_SIZE import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, ScroogeLikeExample} import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder} +import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi} import org.apache.spark.sql.catalyst.util.sideBySide import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SQLExecution} @@ -2416,6 +2417,50 @@ class DatasetSuite extends QueryTest assert(parquetFiles.size === 10) } } + + test("SPARK-37829: DataFrame outer join") { + // Same as "SPARK-15441: Dataset outer join" but using DataFrames instead of Datasets + val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDF().as("left") + val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDF().as("right") + val joined = left.joinWith(right, $"left.b" === $"right.b", "left") + + val leftFieldSchema = StructType( + Seq( + StructField("a", StringType), + StructField("b", IntegerType, nullable = false) + ) + ) + val rightFieldSchema = StructType( + Seq( + StructField("a", StringType), + StructField("b", IntegerType, nullable = false) + ) + ) + val expectedSchema = StructType( + Seq( + StructField( + "_1", + leftFieldSchema, + nullable = false + ), + // This is a left join, so 
the right output is nullable: + StructField( + "_2", + rightFieldSchema + ) + ) + ) + assert(joined.schema === expectedSchema) + + val result = joined.collect().toSet + val expected = Set( + new GenericRowWithSchema(Array("a", 1), leftFieldSchema) -> + null, + new GenericRowWithSchema(Array("b", 2), leftFieldSchema) -> + new GenericRowWithSchema(Array("x", 2), rightFieldSchema) + ) + assert(result == expected) + } } class DatasetLargeResultCollectingSuite extends QueryTest --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org