Repository: spark
Updated Branches:
  refs/heads/master 78c8bd2e6 -> 736fc0393


[SPARK-25791][SQL] Datatype of serializers in RowEncoder should be accessible

## What changes were proposed in this pull request?

The serializers of `RowEncoder` use a few `If` Catalyst expressions. `If` inherits 
`ComplexTypeMergingExpression`, which checks the data types of its inputs.

It is possible to generate serializers that fail this check, so the data type of 
the serializers cannot be accessed. When producing an `If` expression, we should 
use the same data type for its input expressions.
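
The mismatch can be illustrated with a small self-contained model. This is only a sketch: every type below (`DataType`, `UserDefinedType`, `NullLiteral`, `Value`, `If`, `SerializerTypeSketch`) is a simplified stand-in written for this example, not the real Catalyst class, and the `require` call only approximates the check that `ComplexTypeMergingExpression` is described as performing.

```scala
// Simplified stand-ins for Catalyst types; NOT the real Spark classes.
sealed trait DataType
case object IntegerType extends DataType
final case class StructType(fields: List[(String, DataType)]) extends DataType
// A user-defined type wraps an underlying SQL type.
final case class UserDefinedType(sqlType: DataType) extends DataType

sealed trait Expr { def dataType: DataType }
final case class NullLiteral(dataType: DataType) extends Expr
final case class Value(dataType: DataType) extends Expr

// Models an If whose branches must agree on type before dataType can be read.
final case class If(trueValue: Expr, falseValue: Expr) extends Expr {
  def dataType: DataType = {
    require(trueValue.dataType == falseValue.dataType,
      s"branch types differ: ${trueValue.dataType} vs ${falseValue.dataType}")
    trueValue.dataType
  }
}

object SerializerTypeSketch {
  val sqlType: DataType = StructType(List(("a", IntegerType)))
  val fieldType: DataType = UserDefinedType(sqlType)
  // Assumption: the serializer strips the UDT, so the non-null branch
  // carries the underlying SQL type rather than the declared field type.
  val fieldValue: Expr = Value(sqlType)

  // Before the patch: the null literal is typed with the declared field type.
  val before: Expr = If(NullLiteral(fieldType), fieldValue)
  // After the patch: the null literal is typed with the non-null branch's type.
  val after: Expr = If(NullLiteral(fieldValue.dataType), fieldValue)

  def main(args: Array[String]): Unit = {
    println(scala.util.Try(before.dataType).isFailure) // true
    println(after.dataType == sqlType)                 // true
  }
}
```

In this model, reading `before.dataType` fails because the two branches disagree, while `after.dataType` returns the underlying SQL type, which mirrors the assertion in the added test.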

## How was this patch tested?

Added test.

Closes #22785 from viirya/SPARK-25791.

Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/736fc039
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/736fc039
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/736fc039

Branch: refs/heads/master
Commit: 736fc03930fd2087b2156d623705963acba13143
Parents: 78c8bd2
Author: Liang-Chi Hsieh <[email protected]>
Authored: Tue Oct 23 22:02:14 2018 +0800
Committer: Wenchen Fan <[email protected]>
Committed: Tue Oct 23 22:02:14 2018 +0800

----------------------------------------------------------------------
 .../org/apache/spark/sql/catalyst/encoders/RowEncoder.scala  | 8 +++++---
 .../apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala | 8 ++++++++
 2 files changed, 13 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/736fc039/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 3340789..ae89f98 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -171,7 +171,7 @@ object RowEncoder {
 
       if (inputObject.nullable) {
         If(IsNull(inputObject),
-          Literal.create(null, inputType),
+          Literal.create(null, nonNullOutput.dataType),
           nonNullOutput)
       } else {
         nonNullOutput
@@ -187,7 +187,9 @@ object RowEncoder {
         val convertedField = if (field.nullable) {
           If(
             Invoke(inputObject, "isNullAt", BooleanType, Literal(index) :: Nil),
-            Literal.create(null, field.dataType),
+            // Because we strip UDTs, `field.dataType` can be different from `fieldValue.dataType`.
+            // We should use `fieldValue.dataType` here.
+            Literal.create(null, fieldValue.dataType),
             fieldValue
           )
         } else {
@@ -198,7 +200,7 @@ object RowEncoder {
 
       if (inputObject.nullable) {
         If(IsNull(inputObject),
-          Literal.create(null, inputType),
+          Literal.create(null, nonNullOutput.dataType),
           nonNullOutput)
       } else {
         nonNullOutput

http://git-wip-us.apache.org/repos/asf/spark/blob/736fc039/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 8d89f9c..2357321 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -273,6 +273,14 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
     assert(e4.getMessage.contains("java.lang.String is not a valid external type"))
   }
 
+  test("SPARK-25791: Datatype of serializers should be accessible") {
+    val udtSQLType = new StructType().add("a", IntegerType)
+    val pythonUDT = new PythonUserDefinedType(udtSQLType, "pyUDT", "serializedPyClass")
+    val schema = new StructType().add("pythonUDT", pythonUDT, true)
+    val encoder = RowEncoder(schema)
+    assert(encoder.serializer(0).dataType == pythonUDT.sqlType)
+  }
+
   for {
     elementType <- Seq(IntegerType, StringType)
     containsNull <- Seq(true, false)

