Repository: spark
Updated Branches:
  refs/heads/branch-2.2 db21b6793 -> 770fd2a23


[SPARK-21300][SQL] ExternalMapToCatalyst should null-check map key prior to converting to internal value.

## What changes were proposed in this pull request?

`ExternalMapToCatalyst` should null-check the map key before converting it to an internal value, so that an appropriate exception is thrown instead of something like an NPE (`NullPointerException`).

## How was this patch tested?

Added a new test, and ran the existing tests.

Author: Takuya UESHIN <[email protected]>

Closes #18524 from ueshin/issues/SPARK-21300.

(cherry picked from commit ce10545d3401c555e56a214b7c2f334274803660)
Signed-off-by: Wenchen Fan <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/770fd2a2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/770fd2a2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/770fd2a2

Branch: refs/heads/branch-2.2
Commit: 770fd2a239798d3fa1cb4223d73cfc57413c0bb8
Parents: db21b67
Author: Takuya UESHIN <[email protected]>
Authored: Wed Jul 5 11:24:38 2017 +0800
Committer: Wenchen Fan <[email protected]>
Committed: Wed Jul 5 11:24:55 2017 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/JavaTypeInference.scala      |  1 +
 .../apache/spark/sql/catalyst/ScalaReflection.scala |  1 +
 .../sql/catalyst/expressions/objects/objects.scala  | 16 +++++++++++++++-
 .../catalyst/encoders/ExpressionEncoderSuite.scala  |  8 +++++++-
 4 files changed, 24 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/770fd2a2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 86a73a3..2698fae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -423,6 +423,7 @@ object JavaTypeInference {
             inputObject,
             ObjectType(keyType.getRawType),
             serializerFor(_, keyType),
+            keyNullable = true,
             ObjectType(valueType.getRawType),
             serializerFor(_, valueType),
             valueNullable = true

http://git-wip-us.apache.org/repos/asf/spark/blob/770fd2a2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 6d1d019..c887634 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -511,6 +511,7 @@ object ScalaReflection extends ScalaReflection {
           inputObject,
           dataTypeFor(keyType),
           serializerFor(_, keyType, keyPath, seenTypeSet),
+          keyNullable = !keyType.typeSymbol.asClass.isPrimitive,
           dataTypeFor(valueType),
           serializerFor(_, valueType, valuePath, seenTypeSet),
           valueNullable = !valueType.typeSymbol.asClass.isPrimitive)

http://git-wip-us.apache.org/repos/asf/spark/blob/770fd2a2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index bedc88e..43cef6c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -659,18 +659,21 @@ object ExternalMapToCatalyst {
       inputMap: Expression,
       keyType: DataType,
       keyConverter: Expression => Expression,
+      keyNullable: Boolean,
       valueType: DataType,
       valueConverter: Expression => Expression,
       valueNullable: Boolean): ExternalMapToCatalyst = {
     val id = curId.getAndIncrement()
     val keyName = "ExternalMapToCatalyst_key" + id
+    val keyIsNull = "ExternalMapToCatalyst_key_isNull" + id
     val valueName = "ExternalMapToCatalyst_value" + id
     val valueIsNull = "ExternalMapToCatalyst_value_isNull" + id
 
     ExternalMapToCatalyst(
       keyName,
+      keyIsNull,
       keyType,
-      keyConverter(LambdaVariable(keyName, "false", keyType, false)),
+      keyConverter(LambdaVariable(keyName, keyIsNull, keyType, keyNullable)),
       valueName,
       valueIsNull,
       valueType,
@@ -686,6 +689,8 @@ object ExternalMapToCatalyst {
  *
  * @param key the name of the map key variable that used when iterate the map, and used as input for
  *            the `keyConverter`
+ * @param keyIsNull the nullability of the map key variable that used when iterate the map, and
+ *                  used as input for the `keyConverter`
  * @param keyType the data type of the map key variable that used when iterate the map, and used as
  *                input for the `keyConverter`
  * @param keyConverter A function that take the `key` as input, and converts it to catalyst format.
@@ -701,6 +706,7 @@ object ExternalMapToCatalyst {
  */
 case class ExternalMapToCatalyst private(
     key: String,
+    keyIsNull: String,
     keyType: DataType,
     keyConverter: Expression,
     value: String,
@@ -731,6 +737,7 @@ case class ExternalMapToCatalyst private(
 
     val keyElementJavaType = ctx.javaType(keyType)
     val valueElementJavaType = ctx.javaType(valueType)
+    ctx.addMutableState("boolean", keyIsNull, "")
     ctx.addMutableState(keyElementJavaType, key, "")
     ctx.addMutableState("boolean", valueIsNull, "")
     ctx.addMutableState(valueElementJavaType, value, "")
@@ -768,6 +775,12 @@ case class ExternalMapToCatalyst private(
         defineEntries -> defineKeyValue
     }
 
+    val keyNullCheck = if (ctx.isPrimitiveType(keyType)) {
+      s"$keyIsNull = false;"
+    } else {
+      s"$keyIsNull = $key == null;"
+    }
+
     val valueNullCheck = if (ctx.isPrimitiveType(valueType)) {
       s"$valueIsNull = false;"
     } else {
@@ -790,6 +803,7 @@ case class ExternalMapToCatalyst private(
           $defineEntries
           while($entries.hasNext()) {
             $defineKeyValue
+            $keyNullCheck
             $valueNullCheck
 
             ${genKeyConverter.code}

http://git-wip-us.apache.org/repos/asf/spark/blob/770fd2a2/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 080f11b..bb1955a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -355,12 +355,18 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
     checkNullable[String](true)
   }
 
-  test("null check for map key") {
+  test("null check for map key: String") {
     val encoder = ExpressionEncoder[Map[String, Int]]()
     val e = intercept[RuntimeException](encoder.toRow(Map(("a", 1), (null, 2))))
     assert(e.getMessage.contains("Cannot use null as map key"))
   }
 
+  test("null check for map key: Integer") {
+    val encoder = ExpressionEncoder[Map[Integer, String]]()
+    val e = intercept[RuntimeException](encoder.toRow(Map((1, "a"), (null, "b"))))
+    assert(e.getMessage.contains("Cannot use null as map key"))
+  }
+
   private def encodeDecodeTest[T : ExpressionEncoder](
       input: T,
       testName: String): Unit = {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to