Repository: spark Updated Branches: refs/heads/master 4852b7d44 -> 873f3ad2b
[SPARK-16167][SQL] RowEncoder should preserve array/map type nullability. ## What changes were proposed in this pull request? Currently `RowEncoder` doesn't preserve the nullability of `ArrayType` or `MapType`. It always returns `containsNull = true` for `ArrayType` and `valueContainsNull = true` for `MapType`, and the nullability of the array or map itself is also always `true`. This PR fixes their nullability. ## How was this patch tested? Added tests to check that `RowEncoder` preserves array/map nullability. Author: Takuya UESHIN <[email protected]> Author: Takuya UESHIN <[email protected]> Closes #13873 from ueshin/issues/SPARK-16167. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/873f3ad2 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/873f3ad2 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/873f3ad2 Branch: refs/heads/master Commit: 873f3ad2b89c955f42fced49dc129e8efa77d044 Parents: 4852b7d Author: Takuya UESHIN <[email protected]> Authored: Wed Jul 5 20:32:47 2017 +0800 Committer: Wenchen Fan <[email protected]> Committed: Wed Jul 5 20:32:47 2017 +0800 ---------------------------------------------------------------------- .../sql/catalyst/encoders/RowEncoder.scala | 25 ++++++++++++--- .../sql/catalyst/encoders/RowEncoderSuite.scala | 33 ++++++++++++++++++++ 2 files changed, 54 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/873f3ad2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index cc32fac..43c35bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -123,7 +123,7 @@ object RowEncoder { inputObject :: Nil, returnNullable = false) - case t @ ArrayType(et, cn) => + case t @ ArrayType(et, containsNull) => et match { case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => StaticInvoke( @@ -132,8 +132,16 @@ object RowEncoder { "toArrayData", inputObject :: Nil, returnNullable = false) + case _ => MapObjects( - element => serializerFor(ValidateExternalType(element, et), et), + element => { + val value = serializerFor(ValidateExternalType(element, et), et) + if (!containsNull) { + AssertNotNull(value, Seq.empty) + } else { + value + } + }, inputObject, ObjectType(classOf[Object])) } @@ -155,10 +163,19 @@ object RowEncoder { ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false) val convertedValues = serializerFor(values, ArrayType(vt, valueNullable)) - NewInstance( + val nonNullOutput = NewInstance( classOf[ArrayBasedMapData], convertedKeys :: convertedValues :: Nil, - dataType = t) + dataType = t, + propagateNull = false) + + if (inputObject.nullable) { + If(IsNull(inputObject), + Literal.create(null, inputType), + nonNullOutput) + } else { + nonNullOutput + } case StructType(fields) => val nonNullOutput = CreateNamedStruct(fields.zipWithIndex.flatMap { case (field, index) => http://git-wip-us.apache.org/repos/asf/spark/blob/873f3ad2/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 1a5569a..6ed175f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -273,6 +273,39 @@ class RowEncoderSuite extends SparkFunSuite { assert(e4.getMessage.contains("java.lang.String is not a valid external type")) } + for { + elementType <- Seq(IntegerType, StringType) + containsNull <- Seq(true, false) + nullable <- Seq(true, false) + } { + test("RowEncoder should preserve array nullability: " + + s"ArrayType($elementType, containsNull = $containsNull), nullable = $nullable") { + val schema = new StructType().add("array", ArrayType(elementType, containsNull), nullable) + val encoder = RowEncoder(schema).resolveAndBind() + assert(encoder.serializer.length == 1) + assert(encoder.serializer.head.dataType == ArrayType(elementType, containsNull)) + assert(encoder.serializer.head.nullable == nullable) + } + } + + for { + keyType <- Seq(IntegerType, StringType) + valueType <- Seq(IntegerType, StringType) + valueContainsNull <- Seq(true, false) + nullable <- Seq(true, false) + } { + test("RowEncoder should preserve map nullability: " + + s"MapType($keyType, $valueType, valueContainsNull = $valueContainsNull), " + + s"nullable = $nullable") { + val schema = new StructType().add( + "map", MapType(keyType, valueType, valueContainsNull), nullable) + val encoder = RowEncoder(schema).resolveAndBind() + assert(encoder.serializer.length == 1) + assert(encoder.serializer.head.dataType == MapType(keyType, valueType, valueContainsNull)) + assert(encoder.serializer.head.nullable == nullable) + } + } + private def encodeDecodeTest(schema: StructType): Unit = { test(s"encode/decode: ${schema.simpleString}") { val encoder = RowEncoder(schema).resolveAndBind() --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
