Repository: spark
Updated Branches:
refs/heads/master 6c6626614 -> 3bba8621c
[SPARK-22378][SQL] Eliminate redundant null check in generated code for
extracting an element from complex types
## What changes were proposed in this pull request?
This PR eliminates redundant null check in generated code for extracting an
element from complex types `GetArrayItem`, `GetMapValue`, and
`GetArrayStructFields`. Since this code generation does not take
`nullable` in the `DataType` (such as `ArrayType`) into account, the generated
code always contains an `isNullAt(index)` check.
This PR avoids generating `isNullAt(index)` when `nullable` is false in the
`DataType`.
Example
```
val nonNullArray = Literal.create(Seq(1), ArrayType(IntegerType, false))
checkEvaluation(GetArrayItem(nonNullArray, Literal(0)), 1)
```
Before this PR
```
/* 034 */ public java.lang.Object apply(java.lang.Object _i) {
/* 035 */ InternalRow i = (InternalRow) _i;
/* 036 */
/* 037 */
/* 038 */
/* 039 */ boolean isNull = true;
/* 040 */ int value = -1;
/* 041 */
/* 042 */
/* 043 */
/* 044 */ isNull = false; // resultCode could change nullability.
/* 045 */
/* 046 */ final int index = (int) 0;
/* 047 */ if (index >= ((ArrayData) references[0]).numElements() || index <
0 || ((ArrayData) references[0]).isNullAt(index)) {
/* 048 */ isNull = true;
/* 049 */ } else {
/* 050 */ value = ((ArrayData) references[0]).getInt(index);
/* 051 */ }
/* 052 */ isNull_0 = isNull;
/* 053 */ value_0 = value;
/* 054 */
/* 055 */ // copy all the results into MutableRow
/* 056 */
/* 057 */ if (!isNull_0) {
/* 058 */ mutableRow.setInt(0, value_0);
/* 059 */ } else {
/* 060 */ mutableRow.setNullAt(0);
/* 061 */ }
/* 062 */
/* 063 */ return mutableRow;
/* 064 */ }
```
After this PR (Line 47 is changed)
```
/* 034 */ public java.lang.Object apply(java.lang.Object _i) {
/* 035 */ InternalRow i = (InternalRow) _i;
/* 036 */
/* 037 */
/* 038 */
/* 039 */ boolean isNull = true;
/* 040 */ int value = -1;
/* 041 */
/* 042 */
/* 043 */
/* 044 */ isNull = false; // resultCode could change nullability.
/* 045 */
/* 046 */ final int index = (int) 0;
/* 047 */ if (index >= ((ArrayData) references[0]).numElements() || index <
0) {
/* 048 */ isNull = true;
/* 049 */ } else {
/* 050 */ value = ((ArrayData) references[0]).getInt(index);
/* 051 */ }
/* 052 */ isNull_0 = isNull;
/* 053 */ value_0 = value;
/* 054 */
/* 055 */ // copy all the results into MutableRow
/* 056 */
/* 057 */ if (!isNull_0) {
/* 058 */ mutableRow.setInt(0, value_0);
/* 059 */ } else {
/* 060 */ mutableRow.setNullAt(0);
/* 061 */ }
/* 062 */
/* 063 */ return mutableRow;
/* 064 */ }
```
## How was this patch tested?
Added test cases to `ComplexTypeSuite`
Author: Kazuaki Ishizaki <[email protected]>
Closes #19598 from kiszk/SPARK-22378.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3bba8621
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3bba8621
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3bba8621
Branch: refs/heads/master
Commit: 3bba8621cf0a97f5c3134c9a160b1c8c5e97ba97
Parents: 6c66266
Author: Kazuaki Ishizaki <[email protected]>
Authored: Sat Nov 4 22:57:12 2017 -0700
Committer: gatorsmile <[email protected]>
Committed: Sat Nov 4 22:57:12 2017 -0700
----------------------------------------------------------------------
.../expressions/complexTypeExtractors.scala | 28 ++++++++++++++++----
.../catalyst/expressions/ComplexTypeSuite.scala | 11 ++++++--
2 files changed, 32 insertions(+), 7 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/3bba8621/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index ef88cfb..7e53ca3 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -186,6 +186,16 @@ case class GetArrayStructFields(
val values = ctx.freshName("values")
val j = ctx.freshName("j")
val row = ctx.freshName("row")
+ val nullSafeEval = if (field.nullable) {
+ s"""
+ if ($row.isNullAt($ordinal)) {
+ $values[$j] = null;
+ } else
+ """
+ } else {
+ ""
+ }
+
s"""
final int $n = $eval.numElements();
final Object[] $values = new Object[$n];
@@ -194,9 +204,7 @@ case class GetArrayStructFields(
$values[$j] = null;
} else {
final InternalRow $row = $eval.getStruct($j, $numFields);
- if ($row.isNullAt($ordinal)) {
- $values[$j] = null;
- } else {
+ $nullSafeEval {
$values[$j] = ${ctx.getValue(row, field.dataType,
ordinal.toString)};
}
}
@@ -242,9 +250,14 @@ case class GetArrayItem(child: Expression, ordinal:
Expression)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
val index = ctx.freshName("index")
+ val nullCheck = if (child.dataType.asInstanceOf[ArrayType].containsNull)
{
+ s" || $eval1.isNullAt($index)"
+ } else {
+ ""
+ }
s"""
final int $index = (int) $eval2;
- if ($index >= $eval1.numElements() || $index < 0 ||
$eval1.isNullAt($index)) {
+ if ($index >= $eval1.numElements() || $index < 0$nullCheck) {
${ev.isNull} = true;
} else {
${ev.value} = ${ctx.getValue(eval1, dataType, index)};
@@ -309,6 +322,11 @@ case class GetMapValue(child: Expression, key: Expression)
val found = ctx.freshName("found")
val key = ctx.freshName("key")
val values = ctx.freshName("values")
+ val nullCheck = if
(child.dataType.asInstanceOf[MapType].valueContainsNull) {
+ s" || $values.isNullAt($index)"
+ } else {
+ ""
+ }
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s"""
final int $length = $eval1.numElements();
@@ -326,7 +344,7 @@ case class GetMapValue(child: Expression, key: Expression)
}
}
- if (!$found || $values.isNullAt($index)) {
+ if (!$found$nullCheck) {
${ev.isNull} = true;
} else {
${ev.value} = ${ctx.getValue(values, dataType, index)};
http://git-wip-us.apache.org/repos/asf/spark/blob/3bba8621/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 5f8a8f4..b0eaad1 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -51,6 +51,9 @@ class ComplexTypeSuite extends SparkFunSuite with
ExpressionEvalHelper {
checkEvaluation(GetArrayItem(array, nullInt), null)
checkEvaluation(GetArrayItem(nullArray, nullInt), null)
+ val nonNullArray = Literal.create(Seq(1), ArrayType(IntegerType, false))
+ checkEvaluation(GetArrayItem(nonNullArray, Literal(0)), 1)
+
val nestedArray = Literal.create(Seq(Seq(1)),
ArrayType(ArrayType(IntegerType)))
checkEvaluation(GetArrayItem(nestedArray, Literal(0)), Seq(1))
}
@@ -66,6 +69,9 @@ class ComplexTypeSuite extends SparkFunSuite with
ExpressionEvalHelper {
checkEvaluation(GetMapValue(nullMap, nullString), null)
checkEvaluation(GetMapValue(map, nullString), null)
+ val nonNullMap = Literal.create(Map("a" -> 1), MapType(StringType,
IntegerType, false))
+ checkEvaluation(GetMapValue(nonNullMap, Literal("a")), 1)
+
val nestedMap = Literal.create(Map("a" -> Map("b" -> "c")),
MapType(StringType, typeM))
checkEvaluation(GetMapValue(nestedMap, Literal("a")), Map("b" -> "c"))
}
@@ -101,9 +107,10 @@ class ComplexTypeSuite extends SparkFunSuite with
ExpressionEvalHelper {
}
test("GetArrayStructFields") {
- val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
+ val typeAS = ArrayType(StructType(StructField("a", IntegerType, false) ::
Nil))
+ val typeNullAS = ArrayType(StructType(StructField("a", IntegerType) ::
Nil))
val arrayStruct = Literal.create(Seq(create_row(1)), typeAS)
- val nullArrayStruct = Literal.create(null, typeAS)
+ val nullArrayStruct = Literal.create(null, typeNullAS)
def getArrayStructFields(expr: Expression, fieldName: String):
GetArrayStructFields = {
expr.dataType match {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]