Repository: spark
Updated Branches:
  refs/heads/master 20591afd7 -> 86761e10e
[SPARK-12478][SQL] Bugfix: Dataset fields of product types can't be null

When creating extractors for product types (i.e. case classes and tuples), a
null check is missing, thus we always assume input product values are
non-null.

This PR adds a null check in the extractor expression for product types. The
null check is stripped off for top level product fields, which are mapped to
the outermost `Row`s, since they can't be null.

Thanks cloud-fan for helping investigate this issue!

Author: Cheng Lian <[email protected]>

Closes #10431 from liancheng/spark-12478.top-level-null-field.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/86761e10
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/86761e10
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/86761e10

Branch: refs/heads/master
Commit: 86761e10e145b6867cbe86b1e924ec237ba408af
Parents: 20591af
Author: Cheng Lian <[email protected]>
Authored: Wed Dec 23 10:21:00 2015 +0800
Committer: Cheng Lian <[email protected]>
Committed: Wed Dec 23 10:21:00 2015 +0800

----------------------------------------------------------------------
 .../org/apache/spark/sql/catalyst/ScalaReflection.scala |  8 ++++----
 .../test/scala/org/apache/spark/sql/DatasetSuite.scala  | 11 +++++++++++
 2 files changed, 15 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/86761e10/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index becd019..8a22b37 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -380,7 +380,7 @@ object ScalaReflection extends ScalaReflection {
     val clsName = getClassNameFromType(tpe)
     val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
     extractorFor(inputObject, tpe, walkedTypePath) match {
-      case s: CreateNamedStruct => s
+      case expressions.If(_, _, s: CreateNamedStruct) if tpe <:< localTypeOf[Product] => s
       case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil)
     }
   }
@@ -466,14 +466,14 @@ object ScalaReflection extends ScalaReflection {
 
       case t if t <:< localTypeOf[Product] =>
         val params = getConstructorParameters(t)
-
-        CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
+        val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
           val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
           val clsName = getClassNameFromType(fieldType)
           val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
-
           expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType, newPath) :: Nil
         })
+        val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
+        expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
 
       case t if t <:< localTypeOf[Array[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
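For readers skimming the diff, the gist of the change as a minimal
standalone Scala sketch (simplified; `extractUnchecked` and `extractChecked`
are illustrative names, not Catalyst code):

    case class ClassData(a: String, b: Int)

    // Before the fix: field extraction assumed a non-null product value,
    // so a null input hit a NullPointerException when its getters were
    // invoked.
    def extractUnchecked(p: ClassData): Seq[Any] = Seq(p.a, p.b)

    // After the fix: extraction is guarded by a null check, analogous to
    // expressions.If(IsNull(inputObject), nullOutput, nonNullOutput) above.
    def extractChecked(p: ClassData): Seq[Any] =
      if (p == null) null else Seq(p.a, p.b)

The top-level case in extractorsFor then strips this If wrapper back off,
since the outermost Row can never be null.
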
http://git-wip-us.apache.org/repos/asf/spark/blob/86761e10/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 3337996..7fe66e4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -546,6 +546,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       "Null value appeared in non-nullable field org.apache.spark.sql.ClassData.b of type Int."
     ))
   }
+
+  test("SPARK-12478: top level null field") {
+    val ds0 = Seq(NestedStruct(null)).toDS()
+    checkAnswer(ds0, NestedStruct(null))
+    checkAnswer(ds0.toDF(), Row(null))
+
+    val ds1 = Seq(DeepNestedStruct(NestedStruct(null))).toDS()
+    checkAnswer(ds1, DeepNestedStruct(NestedStruct(null)))
+    checkAnswer(ds1.toDF(), Row(Row(null)))
+  }
 }
 
 case class ClassData(a: String, b: Int)
@@ -553,6 +563,7 @@ case class ClassData2(c: String, d: Int)
 case class ClassNullableData(a: String, b: Integer)
 
 case class NestedStruct(f: ClassData)
+case class DeepNestedStruct(f: NestedStruct)
 
 /**
  * A class used to test serialization using encoders. This class throws exceptions when using
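The user-visible behavior the new test pins down, as a sketch of an
interactive session (assuming a 1.6-era `sqlContext` with its implicits
imported; exact shell output may differ):

    import sqlContext.implicits._

    case class ClassData(a: String, b: Int)
    case class NestedStruct(f: ClassData)

    // A null product field now round-trips instead of being assumed
    // non-null by the extractor.
    val ds = Seq(NestedStruct(null)).toDS()
    ds.collect()        // Array(NestedStruct(null))
    ds.toDF().collect() // Array([null]), i.e. Row(null)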
