Repository: spark
Updated Branches:
  refs/heads/master 39872af88 -> c9d7d83ed


[SPARK-25388][TEST][SQL] Detect incorrect nullable of DataType in the result

## What changes were proposed in this pull request?

This PR makes the test helpers correctly raise an assertion failure when the function under test produces a result whose `DataType` has an incorrect nullability flag.

Consider the following example, in which a developer has written incorrect code that returns an unexpected result. This test should fail, because `valueContainsNull=false` while `expr` contains a `null` value. Without this PR, however, the test passes; with this PR, it correctly fails.

```
test("test TARGETFUNCTION") {
  // TARGETMAPFUNCTION is a placeholder for the (buggy) function under test
  val expr = TARGETMAPFUNCTION()
  // expr = UnsafeMap(3 -> 6, 7 -> null)
  // expr.dataType = MapType(IntegerType, IntegerType, valueContainsNull = false)

  val expected = Map(3 -> 6, 7 -> null)
  checkEvaluation(expr, expected)
}
```
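To see why this mismatch should be caught, here is a minimal, Spark-free model of a nullability-aware comparison in the spirit of the `checkResult` change in this commit. It is only a sketch: `DType`, `IntT`, `ArrayT`, `MapT`, `MapVal` and `check` are hypothetical simplifications, not Spark's API (the real helper works on `DataType`, `ArrayData` and `MapData`). The key point is that the element-level flags (`containsNull`, `valueContainsNull`) are passed down as the `nullable` flag of each element, so a `null` value under `valueContainsNull = false` is rejected.

```scala
object NullabilityCheckSketch {
  // Hypothetical, simplified stand-ins for Spark's DataType tree (not Spark's API).
  sealed trait DType
  case object IntT extends DType
  final case class ArrayT(elementType: DType, containsNull: Boolean) extends DType
  final case class MapT(keyType: DType, valueType: DType, valueContainsNull: Boolean) extends DType

  // Hypothetical stand-in for MapData: parallel key and value arrays.
  final case class MapVal(keys: Seq[Any], values: Seq[Any])

  // Compares `result` with `expected`, first requiring that a null result is only
  // allowed when `nullable` is true, then recursing with the element-level flags.
  def check(result: Any, expected: Any, dt: DType, nullable: Boolean): Boolean = {
    require(result != null || nullable, "nullable should be true if result is null")
    (result, expected, dt) match {
      case (r: Seq[_], e: Seq[_], ArrayT(et, containsNull)) =>
        r.length == e.length &&
          r.zip(e).forall { case (a, b) => check(a, b, et, containsNull) }
      case (MapVal(rk, rv), MapVal(ek, ev), MapT(kt, vt, valueContainsNull)) =>
        // Keys can never be null; values may be null only if valueContainsNull is true.
        check(rk, ek, ArrayT(kt, containsNull = false), nullable = false) &&
          check(rv, ev, ArrayT(vt, valueContainsNull), nullable = false)
      case (r, e, _) => r == e
    }
  }

  def main(args: Array[String]): Unit = {
    val result = MapVal(Seq(3, 7), Seq(6, null))
    val expected = MapVal(Seq(3, 7), Seq(6, null))
    // Correct declaration: the comparison succeeds.
    println(check(result, expected, MapT(IntT, IntT, valueContainsNull = true), nullable = false))
    // Incorrect declaration, as in the example above: the null value is rejected.
    val rejected =
      try { check(result, expected, MapT(IntT, IntT, valueContainsNull = false), nullable = false); false }
      catch { case _: IllegalArgumentException => true }
    println(s"incorrect valueContainsNull detected: $rejected")
  }
}
```

In the actual commit the failure surfaces through ScalaTest's `assert` rather than `require`, which is why the new test in `ExpressionEvalHelperSuite` intercepts a `RuntimeException`.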

In [`checkEvaluationWithUnsafeProjection`](https://github.com/apache/spark/blob/master/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala#L208-L235), the results are compared as `UnsafeRow`s: the given `expected` value is [converted](https://github.com/apache/spark/blob/master/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala#L226-L227) to an `UnsafeRow` using the `DataType` of `expr`.
```
val lit = InternalRow(expected, expected)
val expectedRow = UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit)
```

In summary, the `UnsafeRow` for `expr` is `[0,1800000038,5000000038,18,2,0,700000003,2,0,6,18,2,0,700000003,2,0,6]` both with and without this PR, while `expected` is converted to:

* without this PR: `[0,1800000038,5000000038,18,2,0,700000003,2,0,6,18,2,0,700000003,2,0,6]`
* with this PR: `[0,1800000038,5000000038,18,2,0,700000003,2,2,6,18,2,0,700000003,2,2,6]`

As a result, without this PR the two rows are identical and the test unexpectedly passes.
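To make the row dumps above easier to read, the sketch below decodes a few of their 64-bit words. It assumes Spark 2.4's unsafe layout as we understand it (an `offsetAndSize` word keeps the offset in the high 32 bits and the size in the low 32 bits; an array stores `numElements`, then a null bitmap, then word-aligned values, with two 4-byte ints per word) and a little-endian machine. `UnsafeWordDecode` and its methods are hypothetical helpers, not Spark APIs. Read this way, each map's value array is `2,<bitmap>,6`, and the only differing word is that bitmap: `0` without this PR, `2` (binary `10`, i.e. element 1, the value for key `7`, is null) with it.

```scala
object UnsafeWordDecode {
  // Assumed layout: offset in the high 32 bits, size in bytes in the low 32 bits.
  def unpackOffsetAndSize(word: Long): (Int, Int) = ((word >>> 32).toInt, word.toInt)

  // Two packed 4-byte ints, assuming little-endian: element 0 in the low 32 bits.
  def unpackInts(word: Long): (Int, Int) = (word.toInt, (word >>> 32).toInt)

  // Indices of the elements whose bit is set in a null-bitmap word.
  def nullIndices(bitmap: Long, numElements: Int): Seq[Int] =
    (0 until numElements).filter(i => ((bitmap >>> i) & 1L) == 1L)

  def main(args: Array[String]): Unit = {
    val hex = (s: String) => java.lang.Long.parseLong(s, 16)
    println(s"first map field (offset, size): ${unpackOffsetAndSize(hex("1800000038"))}")
    println(s"keys: ${unpackInts(hex("700000003"))}")
    println(s"null values without this PR: ${nullIndices(hex("0"), 2)}")
    println(s"null values with this PR: ${nullIndices(hex("2"), 2)}")
  }
}
```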

This is because, without this PR, the generated projection code for `expected` trusts the given `dataType` (whose `valueContainsNull` is `false`) and therefore never sets the null bit:
```
// tmpInput_2 is expected
for (int index_1 = 0; index_1 < numElements_1; index_1++) {
  mutableStateArray_1[1].write(index_1, tmpInput_2.getInt(index_1));
}
```

With this PR, the generated projection code for `expected` always calls `isNullAt` to decide whether the null bit must be set:
```
// tmpInput_2 is expected
for (int index_1 = 0; index_1 < numElements_1; index_1++) {
  if (tmpInput_2.isNullAt(index_1)) {
    mutableStateArray_1[1].setNull4Bytes(index_1);
  } else {
    mutableStateArray_1[1].write(index_1, tmpInput_2.getInt(index_1));
  }
}
```
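The difference between the two generated loops can be modeled without Spark. In this sketch, `NullBitSketch.write` is a hypothetical stand-in for the projection's array writer, not Spark's API. With `containsNull = false` it skips the null check, so a `null` element is written as `0` (in Scala, `null.asInstanceOf[Int]` is `0`) and becomes indistinguishable from a genuine `0`; with `containsNull = true` it sets the element's null bit instead, analogous to `setNull4Bytes`.

```scala
object NullBitSketch {
  // Toy model of writing an int array: returns (null bitmap, values).
  def write(input: IndexedSeq[Any], containsNull: Boolean): (Long, Vector[Int]) = {
    var nullBits = 0L
    val values = Array.fill(input.length)(0)
    for (i <- input.indices) {
      if (containsNull && input(i) == null) {
        nullBits |= 1L << i // analogous to setNull4Bytes(i)
      } else {
        // analogous to write(i, getInt(i)); a null element unboxes to 0
        values(i) = input(i).asInstanceOf[Int]
      }
    }
    (nullBits, values.toVector)
  }

  def main(args: Array[String]): Unit = {
    println(write(Vector(6, null), containsNull = false)) // null silently becomes 0
    println(write(Vector(6, 0), containsNull = false))    // same output as the line above
    println(write(Vector(6, null), containsNull = true))  // bit 1 is set in the bitmap
  }
}
```

For the values `(6, null)` the last call produces a bitmap of `2`, the same word that differs in the "with this PR" dump.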

## How was this patch tested?

Existing unit tests, plus a new test in `ExpressionEvalHelperSuite`.

Closes #22375 from kiszk/SPARK-25388.

Authored-by: Kazuaki Ishizaki <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c9d7d83e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c9d7d83e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c9d7d83e

Branch: refs/heads/master
Commit: c9d7d83ed5790aa272e969af36fd0cb90231111f
Parents: 39872af
Author: Kazuaki Ishizaki <[email protected]>
Authored: Fri Oct 12 11:14:35 2018 +0800
Committer: Wenchen Fan <[email protected]>
Committed: Fri Oct 12 11:14:35 2018 +0800

----------------------------------------------------------------------
 .../expressions/CodeGenerationSuite.scala       | 14 +++---
 .../expressions/ExpressionEvalHelper.scala      | 46 +++++++++++++-------
 .../expressions/ExpressionEvalHelperSuite.scala | 27 +++++++++++-
 .../execution/ObjectHashAggregateSuite.scala    |  2 +-
 4 files changed, 64 insertions(+), 25 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c9d7d83e/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 5e8113a..7843003 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -113,7 +113,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     assert(actual.length == 1)
     val expected = UTF8String.fromString("abc")
 
-    if (!checkResult(actual.head, expected, expressions.head.dataType)) {
+    if (!checkResult(actual.head, expected, expressions.head)) {
       fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
     }
   }
@@ -126,7 +126,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     assert(actual.length == 1)
     val expected = UnsafeArrayData.fromPrimitiveArray(Array.fill(length)(true))
 
-    if (!checkResult(actual.head, expected, expressions.head.dataType)) {
+    if (!checkResult(actual.head, expected, expressions.head)) {
       fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
     }
   }
@@ -142,7 +142,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     assert(actual.length == 1)
     val expected = ArrayBasedMapData((0 until length).toArray, Array.fill(length)(true))
 
-    if (!checkResult(actual.head, expected, expressions.head.dataType)) {
+    if (!checkResult(actual.head, expected, expressions.head)) {
       fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
     }
   }
@@ -154,7 +154,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
     val expected = Seq(InternalRow(Seq.fill(length)(true): _*))
 
-    if (!checkResult(actual, expected, expressions.head.dataType)) {
+    if (!checkResult(actual, expected, expressions.head)) {
       fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
     }
   }
@@ -170,7 +170,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     assert(actual.length == 1)
     val expected = InternalRow(Seq.fill(length)(true): _*)
 
-    if (!checkResult(actual.head, expected, expressions.head.dataType)) {
+    if (!checkResult(actual.head, expected, expressions.head)) {
       fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
     }
   }
@@ -375,7 +375,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     assert(actualOr.length == 1)
     val expectedOr = false
 
-    if (!checkResult(actualOr.head, expectedOr, exprOr.dataType)) {
+    if (!checkResult(actualOr.head, expectedOr, exprOr)) {
       fail(s"Incorrect Evaluation: expressions: $exprOr, actual: $actualOr, expected: $expectedOr")
     }
 
@@ -389,7 +389,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     assert(actualAnd.length == 1)
     val expectedAnd = false
 
-    if (!checkResult(actualAnd.head, expectedAnd, exprAnd.dataType)) {
+    if (!checkResult(actualAnd.head, expectedAnd, exprAnd)) {
       fail(
         s"Incorrect Evaluation: expressions: $exprAnd, actual: $actualAnd, expected: $expectedAnd")
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/c9d7d83e/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index b5986aa..da18475 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -69,11 +69,22 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
 
   /**
    * Check the equality between result of expression and expected value, it will handle
-   * Array[Byte], Spread[Double], MapData and Row.
+   * Array[Byte], Spread[Double], MapData and Row. Also check whether nullable in expression is
+   * true if result is null
    */
-  protected def checkResult(result: Any, expected: Any, exprDataType: DataType): Boolean = {
+  protected def checkResult(result: Any, expected: Any, expression: Expression): Boolean = {
+    checkResult(result, expected, expression.dataType, expression.nullable)
+  }
+
+  protected def checkResult(
+      result: Any,
+      expected: Any,
+      exprDataType: DataType,
+      exprNullable: Boolean): Boolean = {
     val dataType = UserDefinedType.sqlType(exprDataType)
 
+    // The result is null for a non-nullable expression
+    assert(result != null || exprNullable, "exprNullable should be true if result is null")
     (result, expected) match {
       case (result: Array[Byte], expected: Array[Byte]) =>
         java.util.Arrays.equals(result, expected)
@@ -83,24 +94,24 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
         val st = dataType.asInstanceOf[StructType]
         assert(result.numFields == st.length && expected.numFields == st.length)
         st.zipWithIndex.forall { case (f, i) =>
-          checkResult(result.get(i, f.dataType), expected.get(i, f.dataType), f.dataType)
+          checkResult(
+            result.get(i, f.dataType), expected.get(i, f.dataType), f.dataType, f.nullable)
         }
       case (result: ArrayData, expected: ArrayData) =>
         result.numElements == expected.numElements && {
-          val et = dataType.asInstanceOf[ArrayType].elementType
+          val ArrayType(et, cn) = dataType.asInstanceOf[ArrayType]
           var isSame = true
           var i = 0
           while (isSame && i < result.numElements) {
-            isSame = checkResult(result.get(i, et), expected.get(i, et), et)
+            isSame = checkResult(result.get(i, et), expected.get(i, et), et, cn)
             i += 1
           }
           isSame
         }
       case (result: MapData, expected: MapData) =>
-        val kt = dataType.asInstanceOf[MapType].keyType
-        val vt = dataType.asInstanceOf[MapType].valueType
-        checkResult(result.keyArray, expected.keyArray, ArrayType(kt)) &&
-          checkResult(result.valueArray, expected.valueArray, ArrayType(vt))
+        val MapType(kt, vt, vcn) = dataType.asInstanceOf[MapType]
+        checkResult(result.keyArray, expected.keyArray, ArrayType(kt, false), false) &&
+          checkResult(result.valueArray, expected.valueArray, ArrayType(vt, vcn), false)
       case (result: Double, expected: Double) =>
         if (expected.isNaN) result.isNaN else expected == result
       case (result: Float, expected: Float) =>
@@ -175,7 +186,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
     val actual = try evaluateWithoutCodegen(expression, inputRow) catch {
       case e: Exception => fail(s"Exception evaluating $expression", e)
     }
-    if (!checkResult(actual, expected, expression.dataType)) {
+    if (!checkResult(actual, expected, expression)) {
       val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
       fail(s"Incorrect evaluation (codegen off): $expression, " +
         s"actual: $actual, " +
@@ -191,7 +202,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
     for (fallbackMode <- modes) {
       withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) {
         val actual = evaluateWithMutableProjection(expression, inputRow)
-        if (!checkResult(actual, expected, expression.dataType)) {
+        if (!checkResult(actual, expected, expression)) {
           val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
           fail(s"Incorrect evaluation (fallback mode = $fallbackMode): $expression, " +
             s"actual: $actual, expected: $expected$input")
@@ -221,6 +232,12 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
         val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow)
         val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
 
+        val dataType = expression.dataType
+        if (!checkResult(unsafeRow.get(0, dataType), expected, dataType, expression.nullable)) {
+          fail("Incorrect evaluation in unsafe mode (fallback mode = $fallbackMode): " +
+            s"$expression, actual: $unsafeRow, expected: $expected, " +
+            s"dataType: $dataType, nullable: ${expression.nullable}")
+        }
         if (expected == null) {
           if (!unsafeRow.isNullAt(0)) {
             val expectedRow = InternalRow(expected, expected)
@@ -229,8 +246,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
           }
         } else {
           val lit = InternalRow(expected, expected)
-          val expectedRow =
-            UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit)
+          val expectedRow = UnsafeProjection.create(Array(dataType, dataType)).apply(lit)
           if (unsafeRow != expectedRow) {
             fail(s"Incorrect evaluation in unsafe mode (fallback mode = $fallbackMode): " +
               s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
@@ -280,7 +296,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
       expression)
     plan.initialize(0)
     var actual = plan(inputRow).get(0, expression.dataType)
-    assert(checkResult(actual, expected, expression.dataType))
+    assert(checkResult(actual, expected, expression))
 
     plan = generateProject(
       GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil),
@@ -288,7 +304,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
     plan.initialize(0)
     actual = FromUnsafeProjection(expression.dataType :: Nil)(
       plan(inputRow)).get(0, expression.dataType)
-    assert(checkResult(actual, expected, expression.dataType))
+    assert(checkResult(actual, expected, expression))
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/c9d7d83e/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala
index 7c7c4cc..54ef964 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala
@@ -19,9 +19,10 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.types.{DataType, IntegerType}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
+import org.apache.spark.sql.types.{DataType, IntegerType, MapType}
 
 /**
  * A test suite for testing [[ExpressionEvalHelper]].
@@ -35,6 +36,13 @@ class ExpressionEvalHelperSuite extends SparkFunSuite with ExpressionEvalHelper
     val e = intercept[RuntimeException] { checkEvaluation(BadCodegenExpression(), 10) }
     assert(e.getMessage.contains("some_variable"))
   }
+
+  test("SPARK-25388: checkEvaluation should fail if nullable in DataType is incorrect") {
+    val e = intercept[RuntimeException] {
+      checkEvaluation(MapIncorrectDataTypeExpression(), Map(3 -> 7, 6 -> null))
+    }
+    assert(e.getMessage.contains("and exprNullable was"))
+  }
 }
 
 /**
@@ -53,3 +61,18 @@ case class BadCodegenExpression() extends LeafExpression {
   }
   override def dataType: DataType = IntegerType
 }
+
+/**
+ * An expression that returns a MapData with incorrect DataType whose valueContainsNull is false
+ * while its value includes null
+ */
+case class MapIncorrectDataTypeExpression() extends LeafExpression with CodegenFallback {
+  override def nullable: Boolean = false
+  override def eval(input: InternalRow): Any = {
+    val keys = new GenericArrayData(Array(3, 6))
+    val values = new GenericArrayData(Array(7, null))
+    new ArrayBasedMapData(keys, values)
+  }
+  // since values includes null, valueContainsNull must be true
+  override def dataType: DataType = MapType(IntegerType, IntegerType, valueContainsNull = false)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/c9d7d83e/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
index 0ef630b..c930919 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
@@ -416,7 +416,7 @@ class ObjectHashAggregateSuite
     actual.zip(expected).foreach { case (lhs: Row, rhs: Row) =>
       assert(lhs.length == rhs.length)
       lhs.toSeq.zip(rhs.toSeq).foreach {
-        case (a: Double, b: Double) => checkResult(a, b +- tolerance, DoubleType)
+        case (a: Double, b: Double) => checkResult(a, b +- tolerance, DoubleType, false)
         case (a, b) => a == b
       }
     }

