Repository: spark
Updated Branches:
refs/heads/master 39872af88 -> c9d7d83ed
[SPARK-25388][TEST][SQL] Detect incorrect nullable of DataType in the result
## What changes were proposed in this pull request?
This PR makes the test helpers raise an assertion failure when the function
under test produces a result whose `DataType` has an incorrect nullable flag.
Consider the following example. Suppose that, in the future, a developer writes
incorrect code that returns an unexpected result. This test should fail,
because `valueContainsNull=false` while `expr` includes `null`.
However, without this PR, the test passes. With this PR, it correctly fails.
```
test("test TARGETFUNCTON") {
val expr = TARGETMAPFUNCTON()
// expr = UnsafeMap(3 -> 6, 7 -> null)
// expr.dataType = (IntegerType, IntegerType, false)
expected = Map(3 -> 6, 7 -> null)
checkEvaluation(expr, expected)
```
In
[`checkEvaluationWithUnsafeProjection`](https://github.com/apache/spark/blob/master/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala#L208-L235),
the results are compared using `UnsafeRow`. The given `expected` is
[converted](https://github.com/apache/spark/blob/master/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala#L226-L227)
to `UnsafeRow` using the `DataType` of `expr`.
```
val expectedRow = UnsafeProjection.create(Array(expression.dataType,
expression.dataType)).apply(lit)
```
In summary, `expr` is
`[0,1800000038,5000000038,18,2,0,700000003,2,0,6,18,2,0,700000003,2,0,6]` both
with and without this PR, whereas `expected` is converted to:
* w/o this PR,
`[0,1800000038,5000000038,18,2,0,700000003,2,0,6,18,2,0,700000003,2,0,6]`
* with this PR,
`[0,1800000038,5000000038,18,2,0,700000003,2,2,6,18,2,0,700000003,2,2,6]`
As a result, without this PR, the test unexpectedly passes.
This is because, without this PR, the generated projection code for `expected`
trusts the given `dataType` and skips setting the null bit.
```
// tmpInput_2 is expected
/* 155 */ for (int index_1 = 0; index_1 < numElements_1; index_1++) {
/* 156 */ mutableStateArray_1[1].write(index_1,
tmpInput_2.getInt(index_1));
/* 157 */ }
```
With this PR, the generated projection code for `expected` always calls
`isNullAt` to check whether the null bit should be set:
```
// tmpInput_2 is expected
/* 161 */ for (int index_1 = 0; index_1 < numElements_1; index_1++) {
/* 162 */
/* 163 */ if (tmpInput_2.isNullAt(index_1)) {
/* 164 */ mutableStateArray_1[1].setNull4Bytes(index_1);
/* 165 */ } else {
/* 166 */ mutableStateArray_1[1].write(index_1,
tmpInput_2.getInt(index_1));
/* 167 */ }
/* 168 */
/* 169 */ }
```
## How was this patch tested?
Existing UTs
Closes #22375 from kiszk/SPARK-25388.
Authored-by: Kazuaki Ishizaki <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c9d7d83e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c9d7d83e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c9d7d83e
Branch: refs/heads/master
Commit: c9d7d83ed5790aa272e969af36fd0cb90231111f
Parents: 39872af
Author: Kazuaki Ishizaki <[email protected]>
Authored: Fri Oct 12 11:14:35 2018 +0800
Committer: Wenchen Fan <[email protected]>
Committed: Fri Oct 12 11:14:35 2018 +0800
----------------------------------------------------------------------
.../expressions/CodeGenerationSuite.scala | 14 +++---
.../expressions/ExpressionEvalHelper.scala | 46 +++++++++++++-------
.../expressions/ExpressionEvalHelperSuite.scala | 27 +++++++++++-
.../execution/ObjectHashAggregateSuite.scala | 2 +-
4 files changed, 64 insertions(+), 25 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/c9d7d83e/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 5e8113a..7843003 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -113,7 +113,7 @@ class CodeGenerationSuite extends SparkFunSuite with
ExpressionEvalHelper {
assert(actual.length == 1)
val expected = UTF8String.fromString("abc")
- if (!checkResult(actual.head, expected, expressions.head.dataType)) {
+ if (!checkResult(actual.head, expected, expressions.head)) {
fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual,
expected: $expected")
}
}
@@ -126,7 +126,7 @@ class CodeGenerationSuite extends SparkFunSuite with
ExpressionEvalHelper {
assert(actual.length == 1)
val expected = UnsafeArrayData.fromPrimitiveArray(Array.fill(length)(true))
- if (!checkResult(actual.head, expected, expressions.head.dataType)) {
+ if (!checkResult(actual.head, expected, expressions.head)) {
fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual,
expected: $expected")
}
}
@@ -142,7 +142,7 @@ class CodeGenerationSuite extends SparkFunSuite with
ExpressionEvalHelper {
assert(actual.length == 1)
val expected = ArrayBasedMapData((0 until length).toArray,
Array.fill(length)(true))
- if (!checkResult(actual.head, expected, expressions.head.dataType)) {
+ if (!checkResult(actual.head, expected, expressions.head)) {
fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual,
expected: $expected")
}
}
@@ -154,7 +154,7 @@ class CodeGenerationSuite extends SparkFunSuite with
ExpressionEvalHelper {
val actual = plan(new
GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
val expected = Seq(InternalRow(Seq.fill(length)(true): _*))
- if (!checkResult(actual, expected, expressions.head.dataType)) {
+ if (!checkResult(actual, expected, expressions.head)) {
fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual,
expected: $expected")
}
}
@@ -170,7 +170,7 @@ class CodeGenerationSuite extends SparkFunSuite with
ExpressionEvalHelper {
assert(actual.length == 1)
val expected = InternalRow(Seq.fill(length)(true): _*)
- if (!checkResult(actual.head, expected, expressions.head.dataType)) {
+ if (!checkResult(actual.head, expected, expressions.head)) {
fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual,
expected: $expected")
}
}
@@ -375,7 +375,7 @@ class CodeGenerationSuite extends SparkFunSuite with
ExpressionEvalHelper {
assert(actualOr.length == 1)
val expectedOr = false
- if (!checkResult(actualOr.head, expectedOr, exprOr.dataType)) {
+ if (!checkResult(actualOr.head, expectedOr, exprOr)) {
fail(s"Incorrect Evaluation: expressions: $exprOr, actual: $actualOr,
expected: $expectedOr")
}
@@ -389,7 +389,7 @@ class CodeGenerationSuite extends SparkFunSuite with
ExpressionEvalHelper {
assert(actualAnd.length == 1)
val expectedAnd = false
- if (!checkResult(actualAnd.head, expectedAnd, exprAnd.dataType)) {
+ if (!checkResult(actualAnd.head, expectedAnd, exprAnd)) {
fail(
s"Incorrect Evaluation: expressions: $exprAnd, actual: $actualAnd,
expected: $expectedAnd")
}
http://git-wip-us.apache.org/repos/asf/spark/blob/c9d7d83e/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index b5986aa..da18475 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -69,11 +69,22 @@ trait ExpressionEvalHelper extends
GeneratorDrivenPropertyChecks with PlanTestBa
/**
* Check the equality between result of expression and expected value, it
will handle
- * Array[Byte], Spread[Double], MapData and Row.
+ * Array[Byte], Spread[Double], MapData and Row. Also check whether nullable
in expression is
+ * true if result is null
*/
- protected def checkResult(result: Any, expected: Any, exprDataType:
DataType): Boolean = {
+ protected def checkResult(result: Any, expected: Any, expression:
Expression): Boolean = {
+ checkResult(result, expected, expression.dataType, expression.nullable)
+ }
+
+ protected def checkResult(
+ result: Any,
+ expected: Any,
+ exprDataType: DataType,
+ exprNullable: Boolean): Boolean = {
val dataType = UserDefinedType.sqlType(exprDataType)
+ // The result is null for a non-nullable expression
+ assert(result != null || exprNullable, "exprNullable should be true if
result is null")
(result, expected) match {
case (result: Array[Byte], expected: Array[Byte]) =>
java.util.Arrays.equals(result, expected)
@@ -83,24 +94,24 @@ trait ExpressionEvalHelper extends
GeneratorDrivenPropertyChecks with PlanTestBa
val st = dataType.asInstanceOf[StructType]
assert(result.numFields == st.length && expected.numFields ==
st.length)
st.zipWithIndex.forall { case (f, i) =>
- checkResult(result.get(i, f.dataType), expected.get(i, f.dataType),
f.dataType)
+ checkResult(
+ result.get(i, f.dataType), expected.get(i, f.dataType),
f.dataType, f.nullable)
}
case (result: ArrayData, expected: ArrayData) =>
result.numElements == expected.numElements && {
- val et = dataType.asInstanceOf[ArrayType].elementType
+ val ArrayType(et, cn) = dataType.asInstanceOf[ArrayType]
var isSame = true
var i = 0
while (isSame && i < result.numElements) {
- isSame = checkResult(result.get(i, et), expected.get(i, et), et)
+ isSame = checkResult(result.get(i, et), expected.get(i, et), et,
cn)
i += 1
}
isSame
}
case (result: MapData, expected: MapData) =>
- val kt = dataType.asInstanceOf[MapType].keyType
- val vt = dataType.asInstanceOf[MapType].valueType
- checkResult(result.keyArray, expected.keyArray, ArrayType(kt)) &&
- checkResult(result.valueArray, expected.valueArray, ArrayType(vt))
+ val MapType(kt, vt, vcn) = dataType.asInstanceOf[MapType]
+ checkResult(result.keyArray, expected.keyArray, ArrayType(kt, false),
false) &&
+ checkResult(result.valueArray, expected.valueArray, ArrayType(vt,
vcn), false)
case (result: Double, expected: Double) =>
if (expected.isNaN) result.isNaN else expected == result
case (result: Float, expected: Float) =>
@@ -175,7 +186,7 @@ trait ExpressionEvalHelper extends
GeneratorDrivenPropertyChecks with PlanTestBa
val actual = try evaluateWithoutCodegen(expression, inputRow) catch {
case e: Exception => fail(s"Exception evaluating $expression", e)
}
- if (!checkResult(actual, expected, expression.dataType)) {
+ if (!checkResult(actual, expected, expression)) {
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
fail(s"Incorrect evaluation (codegen off): $expression, " +
s"actual: $actual, " +
@@ -191,7 +202,7 @@ trait ExpressionEvalHelper extends
GeneratorDrivenPropertyChecks with PlanTestBa
for (fallbackMode <- modes) {
withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) {
val actual = evaluateWithMutableProjection(expression, inputRow)
- if (!checkResult(actual, expected, expression.dataType)) {
+ if (!checkResult(actual, expected, expression)) {
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
fail(s"Incorrect evaluation (fallback mode = $fallbackMode):
$expression, " +
s"actual: $actual, expected: $expected$input")
@@ -221,6 +232,12 @@ trait ExpressionEvalHelper extends
GeneratorDrivenPropertyChecks with PlanTestBa
val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow)
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
+ val dataType = expression.dataType
+ if (!checkResult(unsafeRow.get(0, dataType), expected, dataType,
expression.nullable)) {
+ fail("Incorrect evaluation in unsafe mode (fallback mode =
$fallbackMode): " +
+ s"$expression, actual: $unsafeRow, expected: $expected, " +
+ s"dataType: $dataType, nullable: ${expression.nullable}")
+ }
if (expected == null) {
if (!unsafeRow.isNullAt(0)) {
val expectedRow = InternalRow(expected, expected)
@@ -229,8 +246,7 @@ trait ExpressionEvalHelper extends
GeneratorDrivenPropertyChecks with PlanTestBa
}
} else {
val lit = InternalRow(expected, expected)
- val expectedRow =
- UnsafeProjection.create(Array(expression.dataType,
expression.dataType)).apply(lit)
+ val expectedRow = UnsafeProjection.create(Array(dataType,
dataType)).apply(lit)
if (unsafeRow != expectedRow) {
fail(s"Incorrect evaluation in unsafe mode (fallback mode =
$fallbackMode): " +
s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
@@ -280,7 +296,7 @@ trait ExpressionEvalHelper extends
GeneratorDrivenPropertyChecks with PlanTestBa
expression)
plan.initialize(0)
var actual = plan(inputRow).get(0, expression.dataType)
- assert(checkResult(actual, expected, expression.dataType))
+ assert(checkResult(actual, expected, expression))
plan = generateProject(
GenerateUnsafeProjection.generate(Alias(expression,
s"Optimized($expression)")() :: Nil),
@@ -288,7 +304,7 @@ trait ExpressionEvalHelper extends
GeneratorDrivenPropertyChecks with PlanTestBa
plan.initialize(0)
actual = FromUnsafeProjection(expression.dataType :: Nil)(
plan(inputRow)).get(0, expression.dataType)
- assert(checkResult(actual, expected, expression.dataType))
+ assert(checkResult(actual, expected, expression))
}
/**
http://git-wip-us.apache.org/repos/asf/spark/blob/c9d7d83e/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala
index 7c7c4cc..54ef964 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala
@@ -19,9 +19,10 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext,
ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext,
CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.types.{DataType, IntegerType}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
+import org.apache.spark.sql.types.{DataType, IntegerType, MapType}
/**
* A test suite for testing [[ExpressionEvalHelper]].
@@ -35,6 +36,13 @@ class ExpressionEvalHelperSuite extends SparkFunSuite with
ExpressionEvalHelper
val e = intercept[RuntimeException] {
checkEvaluation(BadCodegenExpression(), 10) }
assert(e.getMessage.contains("some_variable"))
}
+
+ test("SPARK-25388: checkEvaluation should fail if nullable in DataType is
incorrect") {
+ val e = intercept[RuntimeException] {
+ checkEvaluation(MapIncorrectDataTypeExpression(), Map(3 -> 7, 6 -> null))
+ }
+ assert(e.getMessage.contains("and exprNullable was"))
+ }
}
/**
@@ -53,3 +61,18 @@ case class BadCodegenExpression() extends LeafExpression {
}
override def dataType: DataType = IntegerType
}
+
+/**
+ * An expression that returns a MapData with incorrect DataType whose
valueContainsNull is false
+ * while its value includes null
+ */
+case class MapIncorrectDataTypeExpression() extends LeafExpression with
CodegenFallback {
+ override def nullable: Boolean = false
+ override def eval(input: InternalRow): Any = {
+ val keys = new GenericArrayData(Array(3, 6))
+ val values = new GenericArrayData(Array(7, null))
+ new ArrayBasedMapData(keys, values)
+ }
+ // since values includes null, valueContainsNull must be true
+ override def dataType: DataType = MapType(IntegerType, IntegerType,
valueContainsNull = false)
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/c9d7d83e/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
----------------------------------------------------------------------
diff --git
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
index 0ef630b..c930919 100644
---
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
+++
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
@@ -416,7 +416,7 @@ class ObjectHashAggregateSuite
actual.zip(expected).foreach { case (lhs: Row, rhs: Row) =>
assert(lhs.length == rhs.length)
lhs.toSeq.zip(rhs.toSeq).foreach {
- case (a: Double, b: Double) => checkResult(a, b +- tolerance,
DoubleType)
+ case (a: Double, b: Double) => checkResult(a, b +- tolerance,
DoubleType, false)
case (a, b) => a == b
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]