Repository: spark Updated Branches: refs/heads/master 78c8bd2e6 -> 736fc0393
[SPARK-25791][SQL] Datatype of serializers in RowEncoder should be accessible ## What changes were proposed in this pull request? The serializers of `RowEncoder` use a few `If` Catalyst expressions, which inherit from `ComplexTypeMergingExpression` and therefore check their input data types. It is possible to generate serializers that fail this check, which makes it impossible to access the data type of the serializers. When producing an `If` expression, we should use the same data type for all of its input expressions. ## How was this patch tested? Added test. Closes #22785 from viirya/SPARK-25791. Authored-by: Liang-Chi Hsieh <[email protected]> Signed-off-by: Wenchen Fan <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/736fc039 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/736fc039 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/736fc039 Branch: refs/heads/master Commit: 736fc03930fd2087b2156d623705963acba13143 Parents: 78c8bd2 Author: Liang-Chi Hsieh <[email protected]> Authored: Tue Oct 23 22:02:14 2018 +0800 Committer: Wenchen Fan <[email protected]> Committed: Tue Oct 23 22:02:14 2018 +0800 ---------------------------------------------------------------------- .../org/apache/spark/sql/catalyst/encoders/RowEncoder.scala | 8 +++++--- .../apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala | 8 ++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/736fc039/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 3340789..ae89f98 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -171,7 +171,7 @@ object RowEncoder { if (inputObject.nullable) { If(IsNull(inputObject), - Literal.create(null, inputType), + Literal.create(null, nonNullOutput.dataType), nonNullOutput) } else { nonNullOutput @@ -187,7 +187,9 @@ object RowEncoder { val convertedField = if (field.nullable) { If( Invoke(inputObject, "isNullAt", BooleanType, Literal(index) :: Nil), - Literal.create(null, field.dataType), + // Because we strip UDTs, `field.dataType` can be different from `fieldValue.dataType`. + // We should use `fieldValue.dataType` here. + Literal.create(null, fieldValue.dataType), fieldValue ) } else { @@ -198,7 +200,7 @@ object RowEncoder { if (inputObject.nullable) { If(IsNull(inputObject), - Literal.create(null, inputType), + Literal.create(null, nonNullOutput.dataType), nonNullOutput) } else { nonNullOutput http://git-wip-us.apache.org/repos/asf/spark/blob/736fc039/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 8d89f9c..2357321 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -273,6 +273,14 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest { assert(e4.getMessage.contains("java.lang.String is not a valid external type")) } + test("SPARK-25791: Datatype of serializers should be accessible") { + val udtSQLType = new StructType().add("a", IntegerType) + val pythonUDT = new PythonUserDefinedType(udtSQLType, "pyUDT", 
"serializedPyClass") + val schema = new StructType().add("pythonUDT", pythonUDT, true) + val encoder = RowEncoder(schema) + assert(encoder.serializer(0).dataType == pythonUDT.sqlType) + } + for { elementType <- Seq(IntegerType, StringType) containsNull <- Seq(true, false) --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
