Repository: spark
Updated Branches:
  refs/heads/branch-1.6 092a9c896 -> bd94793eb


[SPARK-11921][SQL] fix `nullable` of encoder schema

Author: Wenchen Fan <[email protected]>

Closes #9906 from cloud-fan/nullable.

(cherry picked from commit f2996e0d12eeb989b1bfa51a3f6fa54ce1ed4fca)
Signed-off-by: Michael Armbrust <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bd94793e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bd94793e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bd94793e

Branch: refs/heads/branch-1.6
Commit: bd94793ebe8393ccec2f973566fd834fc55ac54d
Parents: 092a9c8
Author: Wenchen Fan <[email protected]>
Authored: Mon Nov 23 10:15:40 2015 -0800
Committer: Michael Armbrust <[email protected]>
Committed: Mon Nov 23 10:16:02 2015 -0800

----------------------------------------------------------------------
 .../catalyst/encoders/ExpressionEncoder.scala   | 15 ++++++--
 .../encoders/ExpressionEncoderSuite.scala       | 38 +++++++++++++++++++-
 2 files changed, 50 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bd94793e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 6eeba14..7bc9aed 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -54,8 +54,13 @@ object ExpressionEncoder {
     val toRowExpression = ScalaReflection.extractorsFor[T](inputObject)
     val fromRowExpression = ScalaReflection.constructorFor[T]
 
+    val schema = ScalaReflection.schemaFor[T] match {
+      case ScalaReflection.Schema(s: StructType, _) => s
+      case ScalaReflection.Schema(dt, nullable) => new StructType().add("value", dt, nullable)
+    }
+
     new ExpressionEncoder[T](
-      toRowExpression.dataType,
+      schema,
       flat,
       toRowExpression.flatten,
       fromRowExpression,
@@ -71,7 +76,13 @@ object ExpressionEncoder {
     encoders.foreach(_.assertUnresolved())
 
     val schema = StructType(encoders.zipWithIndex.map {
-      case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
+      case (e, i) =>
+        val (dataType, nullable) = if (e.flat) {
+          e.schema.head.dataType -> e.schema.head.nullable
+        } else {
+          e.schema -> true
+        }
+        StructField(s"_${i + 1}", dataType, nullable)
     })
 
     val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")

http://git-wip-us.apache.org/repos/asf/spark/blob/bd94793e/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 76459b3..d6ca138 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.Encoders
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.util.ArrayData
 import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
-import org.apache.spark.sql.types.ArrayType
+import org.apache.spark.sql.types.{StructType, ArrayType}
 
 case class RepeatedStruct(s: Seq[PrimitiveData])
 
@@ -238,6 +238,42 @@ class ExpressionEncoderSuite extends SparkFunSuite {
     ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc))
   }
 
+  test("nullable of encoder schema") {
+    def checkNullable[T: ExpressionEncoder](nullable: Boolean*): Unit = {
+      assert(implicitly[ExpressionEncoder[T]].schema.map(_.nullable) === nullable.toSeq)
+    }
+
+    // test for flat encoders
+    checkNullable[Int](false)
+    checkNullable[Option[Int]](true)
+    checkNullable[java.lang.Integer](true)
+    checkNullable[String](true)
+
+    // test for product encoders
+    checkNullable[(String, Int)](true, false)
+    checkNullable[(Int, java.lang.Long)](false, true)
+
+    // test for nested product encoders
+    {
+      val schema = ExpressionEncoder[(Int, (String, Int))].schema
+      assert(schema(0).nullable === false)
+      assert(schema(1).nullable === true)
+      assert(schema(1).dataType.asInstanceOf[StructType](0).nullable === true)
+      assert(schema(1).dataType.asInstanceOf[StructType](1).nullable === false)
+    }
+
+    // test for tupled encoders
+    {
+      val schema = ExpressionEncoder.tuple(
+        ExpressionEncoder[Int],
+        ExpressionEncoder[(String, Int)]).schema
+      assert(schema(0).nullable === false)
+      assert(schema(1).nullable === true)
+      assert(schema(1).dataType.asInstanceOf[StructType](0).nullable === true)
+      assert(schema(1).dataType.asInstanceOf[StructType](1).nullable === false)
+    }
+  }
+
   private val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap()
   outers.put(getClass.getName, this)
   private def encodeDecodeTest[T : ExpressionEncoder](


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to