Repository: spark
Updated Branches:
  refs/heads/master 70c5549ee -> 38b9e6962


[SPARK-18284][SQL] Make ExpressionEncoder.serializer.nullable precise

## What changes were proposed in this pull request?

This PR makes `ExpressionEncoder.serializer.nullable` `false` for a flat encoder of a 
primitive type. It is currently always `true`, which is too conservative: while 
`ExpressionEncoder.schema` carries the correct nullability (e.g. `<IntegerType, 
false>`), `serializer.head.nullable` of the `ExpressionEncoder` obtained from 
`encoderFor[T]` is always `true`.
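
Concretely, before this patch (a REPL sketch using `ExpressionEncoder[Int]()`, the same constructor the test suites below use):

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

val enc = ExpressionEncoder[Int]()
enc.schema.head.nullable      // false: the schema side is already precise
enc.serializer.head.nullable  // true before this patch, false after it
```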

This is accomplished by checking whether the type is a primitive type; if it is, 
`nullable` is set to `false`.
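
For illustration, a minimal sketch of the rule (the helper below is hypothetical; the patch itself applies the check via `cls.isPrimitive` in `ExpressionEncoder.apply` and `typeSymbol.asClass.isPrimitive` in `ScalaReflection`):

```scala
import scala.reflect.{classTag, ClassTag}

// Hypothetical helper mirroring the patch's rule: a flat encoder's serializer
// is nullable exactly when the runtime class is not a JVM primitive.
def serializerNullable[T: ClassTag]: Boolean =
  !classTag[T].runtimeClass.isPrimitive

serializerNullable[Int]               // false: Int erases to the primitive int
serializerNullable[java.lang.Integer] // true: a boxed Integer may be null
serializerNullable[String]            // true: any reference may be null
```

These expectations match the new `checkNullable` assertions in `ExpressionEncoderSuite` below.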

## How was this patch tested?

Added new tests for encoders and DataFrames.
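
For example, the user-visible effect on DataFrame schemas (excerpted from the new `DatasetSuite` test below; assumes a `SparkSession` with `spark.implicits._` in scope):

```scala
import spark.implicits._

val df1 = Seq(1, 2, 3, 4).toDF           // element type Int is primitive
assert(df1.schema(0).nullable == false)  // was reported as nullable before

val df2 = Seq(Integer.valueOf(1), Integer.valueOf(2)).toDF  // boxed values
assert(df2.schema(0).nullable == true)   // a boxed Integer may be null
```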

Author: Kazuaki Ishizaki <ishiz...@jp.ibm.com>

Closes #15780 from kiszk/SPARK-18284.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/38b9e696
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/38b9e696
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/38b9e696

Branch: refs/heads/master
Commit: 38b9e69623c14a675b14639e8291f5d29d2a0bc3
Parents: 70c5549
Author: Kazuaki Ishizaki <ishiz...@jp.ibm.com>
Authored: Fri Dec 2 12:30:13 2016 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Fri Dec 2 12:30:13 2016 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/JavaTypeInference.scala  |  4 +-
 .../spark/sql/catalyst/ScalaReflection.scala    |  7 ++-
 .../catalyst/encoders/ExpressionEncoder.scala   |  7 +--
 .../expressions/ReferenceToExpressions.scala    |  2 +-
 .../catalyst/expressions/objects/objects.scala  | 24 +++++----
 .../encoders/ExpressionEncoderSuite.scala       | 19 ++++++-
 .../org/apache/spark/sql/DatasetSuite.scala     | 52 +++++++++++++++++++-
 .../sql/streaming/FileStreamSinkSuite.scala     |  2 +-
 8 files changed, 96 insertions(+), 21 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/38b9e696/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 04f0cfc..7e8e4da 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -396,12 +396,14 @@ object JavaTypeInference {
 
         case _ if mapType.isAssignableFrom(typeToken) =>
           val (keyType, valueType) = mapKeyValueType(typeToken)
+
           ExternalMapToCatalyst(
             inputObject,
             ObjectType(keyType.getRawType),
             serializerFor(_, keyType),
             ObjectType(valueType.getRawType),
-            serializerFor(_, valueType)
+            serializerFor(_, valueType),
+            valueNullable = true
           )
 
         case other =>

http://git-wip-us.apache.org/repos/asf/spark/blob/38b9e696/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 0aa21b9..6e20096 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -498,7 +498,8 @@ object ScalaReflection extends ScalaReflection {
           dataTypeFor(keyType),
           serializerFor(_, keyType, keyPath),
           dataTypeFor(valueType),
-          serializerFor(_, valueType, valuePath))
+          serializerFor(_, valueType, valuePath),
+          valueNullable = !valueType.typeSymbol.asClass.isPrimitive)
 
       case t if t <:< localTypeOf[String] =>
         StaticInvoke(
@@ -590,7 +591,9 @@ object ScalaReflection extends ScalaReflection {
               "cannot be used as field name\n" + walkedTypePath.mkString("\n"))
           }
 
-          val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
+          val fieldValue = Invoke(
+            AssertNotNull(inputObject, walkedTypePath), fieldName, dataTypeFor(fieldType),
+            returnNullable = !fieldType.typeSymbol.asClass.isPrimitive)
           val clsName = getClassNameFromType(fieldType)
           val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
           expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/38b9e696/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 9c4818d..3757ecc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -60,7 +60,7 @@ object ExpressionEncoder {
     val cls = mirror.runtimeClass(tpe)
     val flat = !ScalaReflection.definedByConstructorParams(tpe)
 
-    val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = true)
+    val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = !cls.isPrimitive)
     val nullSafeInput = if (flat) {
       inputObject
     } else {
@@ -71,10 +71,7 @@ object ExpressionEncoder {
     val serializer = ScalaReflection.serializerFor[T](nullSafeInput)
     val deserializer = ScalaReflection.deserializerFor[T]
 
-    val schema = ScalaReflection.schemaFor[T] match {
-      case ScalaReflection.Schema(s: StructType, _) => s
-      case ScalaReflection.Schema(dt, nullable) => new StructType().add("value", dt, nullable)
-    }
+    val schema = serializer.dataType
 
     new ExpressionEncoder[T](
       schema,

http://git-wip-us.apache.org/repos/asf/spark/blob/38b9e696/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
index 6c75a7a..2ca77e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
@@ -74,7 +74,7 @@ case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
         ctx.addMutableState("boolean", classChildVarIsNull, "")
 
         val classChildVar =
-          LambdaVariable(classChildVarName, classChildVarIsNull, child.dataType)
+          LambdaVariable(classChildVarName, classChildVarIsNull, child.dataType, child.nullable)
 
         val initCode = s"${classChildVar.value} = ${childGen.value};\n" +
           s"${classChildVar.isNull} = ${childGen.isNull};"

http://git-wip-us.apache.org/repos/asf/spark/blob/38b9e696/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index e517ec1..a8aa1e7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -171,15 +171,18 @@ case class StaticInvoke(
  * @param arguments An optional list of expressions, whose evaluation will be passed to the function.
  * @param propagateNull When true, and any of the arguments is null, null will be returned instead
  *                      of calling the function.
+ * @param returnNullable When false, indicates that the invoked method will always return
+ *                       a non-null value.
  */
 case class Invoke(
     targetObject: Expression,
     functionName: String,
     dataType: DataType,
     arguments: Seq[Expression] = Nil,
-    propagateNull: Boolean = true) extends InvokeLike {
+    propagateNull: Boolean = true,
+    returnNullable: Boolean = true) extends InvokeLike {
 
-  override def nullable: Boolean = true
+  override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable
   override def children: Seq[Expression] = targetObject +: arguments
 
   override def eval(input: InternalRow): Any =
@@ -405,13 +408,15 @@ case class WrapOption(child: Expression, optType: DataType)
  * A place holder for the loop variable used in [[MapObjects]].  This should never be constructed
  * manually, but will instead be passed into the provided lambda function.
  */
-case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends LeafExpression
+case class LambdaVariable(
+    value: String,
+    isNull: String,
+    dataType: DataType,
+    nullable: Boolean = true) extends LeafExpression
   with Unevaluable with NonSQLExpression {
 
-  override def nullable: Boolean = true
-
   override def genCode(ctx: CodegenContext): ExprCode = {
-    ExprCode(code = "", value = value, isNull = isNull)
+    ExprCode(code = "", value = value, isNull = if (nullable) isNull else "false")
   }
 }
 
@@ -592,7 +597,8 @@ object ExternalMapToCatalyst {
       keyType: DataType,
       keyConverter: Expression => Expression,
       valueType: DataType,
-      valueConverter: Expression => Expression): ExternalMapToCatalyst = {
+      valueConverter: Expression => Expression,
+      valueNullable: Boolean): ExternalMapToCatalyst = {
     val id = curId.getAndIncrement()
     val keyName = "ExternalMapToCatalyst_key" + id
     val valueName = "ExternalMapToCatalyst_value" + id
@@ -601,11 +607,11 @@ object ExternalMapToCatalyst {
     ExternalMapToCatalyst(
       keyName,
       keyType,
-      keyConverter(LambdaVariable(keyName, "false", keyType)),
+      keyConverter(LambdaVariable(keyName, "false", keyType, false)),
       valueName,
       valueIsNull,
       valueType,
-      valueConverter(LambdaVariable(valueName, valueIsNull, valueType)),
+      valueConverter(LambdaVariable(valueName, valueIsNull, valueType, valueNullable)),
       inputMap
     )
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/38b9e696/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 4d896c2..080f11b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -24,7 +24,7 @@ import java.util.Arrays
 import scala.collection.mutable.ArrayBuffer
 import scala.reflect.runtime.universe.TypeTag
 
-import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.{Encoder, Encoders}
 import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
 import org.apache.spark.sql.catalyst.analysis.AnalysisTest
 import org.apache.spark.sql.catalyst.dsl.plans._
@@ -300,6 +300,11 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
   encodeDecodeTest(
     ReferenceValueClass(ReferenceValueClass.Container(1)), "reference value class")
 
+  encodeDecodeTest(Option(31), "option of int")
+  encodeDecodeTest(Option.empty[Int], "empty option of int")
+  encodeDecodeTest(Option("abc"), "option of string")
+  encodeDecodeTest(Option.empty[String], "empty option of string")
+
   productTest(("UDT", new ExamplePoint(0.1, 0.2)))
 
   test("nullable of encoder schema") {
@@ -338,6 +343,18 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
     }
   }
 
+  test("nullable of encoder serializer") {
+    def checkNullable[T: Encoder](nullable: Boolean): Unit = {
+      assert(encoderFor[T].serializer.forall(_.nullable === nullable))
+    }
+
+    // test for flat encoders
+    checkNullable[Int](false)
+    checkNullable[Option[Int]](true)
+    checkNullable[java.lang.Integer](true)
+    checkNullable[String](true)
+  }
+
   test("null check for map key") {
     val encoder = ExpressionEncoder[Map[String, Int]]()
    val e = intercept[RuntimeException](encoder.toRow(Map(("a", 1), (null, 2))))

http://git-wip-us.apache.org/repos/asf/spark/blob/38b9e696/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 1174d73..d31c766 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -28,7 +28,10 @@ import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+import org.apache.spark.sql.types._
+
+case class TestDataPoint(x: Int, y: Double, s: String, t: TestDataPoint2)
+case class TestDataPoint2(x: Int, s: String)
 
 class DatasetSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
@@ -969,6 +972,53 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     assert(dataset.collect() sameElements Array(resultValue, resultValue))
   }
 
+  test("SPARK-18284: Serializer should have correct nullable value") {
+    val df1 = Seq(1, 2, 3, 4).toDF
+    assert(df1.schema(0).nullable == false)
+    val df2 = Seq(Integer.valueOf(1), Integer.valueOf(2)).toDF
+    assert(df2.schema(0).nullable == true)
+
+    val df3 = Seq(Seq(1, 2), Seq(3, 4)).toDF
+    assert(df3.schema(0).nullable == true)
+    assert(df3.schema(0).dataType.asInstanceOf[ArrayType].containsNull == false)
+    val df4 = Seq(Seq("a", "b"), Seq("c", "d")).toDF
+    assert(df4.schema(0).nullable == true)
+    assert(df4.schema(0).dataType.asInstanceOf[ArrayType].containsNull == true)
+
+    val df5 = Seq((0, 1.0), (2, 2.0)).toDF("id", "v")
+    assert(df5.schema(0).nullable == false)
+    assert(df5.schema(1).nullable == false)
+    val df6 = Seq((0, 1.0, "a"), (2, 2.0, "b")).toDF("id", "v1", "v2")
+    assert(df6.schema(0).nullable == false)
+    assert(df6.schema(1).nullable == false)
+    assert(df6.schema(2).nullable == true)
+
+    val df7 = (Tuple1(Array(1, 2, 3)) :: Nil).toDF("a")
+    assert(df7.schema(0).nullable == true)
+    assert(df7.schema(0).dataType.asInstanceOf[ArrayType].containsNull == false)
+
+    val df8 = (Tuple1(Array((null: Integer), (null: Integer))) :: Nil).toDF("a")
+    assert(df8.schema(0).nullable == true)
+    assert(df8.schema(0).dataType.asInstanceOf[ArrayType].containsNull == true)
+
+    val df9 = (Tuple1(Map(2 -> 3)) :: Nil).toDF("m")
+    assert(df9.schema(0).nullable == true)
+    assert(df9.schema(0).dataType.asInstanceOf[MapType].valueContainsNull == false)
+
+    val df10 = (Tuple1(Map(1 -> (null: Integer))) :: Nil).toDF("m")
+    assert(df10.schema(0).nullable == true)
+    assert(df10.schema(0).dataType.asInstanceOf[MapType].valueContainsNull == true)
+
+    val df11 = Seq(TestDataPoint(1, 2.2, "a", null),
+                   TestDataPoint(3, 4.4, "null", (TestDataPoint2(33, "b")))).toDF
+    assert(df11.schema(0).nullable == false)
+    assert(df11.schema(1).nullable == false)
+    assert(df11.schema(2).nullable == true)
+    assert(df11.schema(3).nullable == true)
+    assert(df11.schema(3).dataType.asInstanceOf[StructType].fields(0).nullable == false)
+    assert(df11.schema(3).dataType.asInstanceOf[StructType].fields(1).nullable == true)
+  }
+
   Seq(true, false).foreach { eager =>
     def testCheckpointing(testName: String)(f: => Unit): Unit = {
       test(s"Dataset.checkpoint() - $testName (eager = $eager)") {

http://git-wip-us.apache.org/repos/asf/spark/blob/38b9e696/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
index 09613ef..54efae3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
@@ -86,7 +86,7 @@ class FileStreamSinkSuite extends StreamTest {
 
       val outputDf = spark.read.parquet(outputDir)
       val expectedSchema = new StructType()
-        .add(StructField("value", IntegerType))
+        .add(StructField("value", IntegerType, nullable = false))
         .add(StructField("id", IntegerType))
       assert(outputDf.schema === expectedSchema)
 

