Repository: spark
Updated Branches:
  refs/heads/master 6c6626614 -> 3bba8621c


[SPARK-22378][SQL] Eliminate redundant null check in generated code for 
extracting an element from complex types

## What changes were proposed in this pull request?

This PR eliminates the redundant null check in the code generated for 
extracting an element from complex types by `GetArrayItem`, `GetMapValue`, 
and `GetArrayStructFields`. Because the code generation for these expressions 
does not take into account the `nullable` information in the `DataType` (for 
example, `containsNull` in `ArrayType`), the generated code always contains 
an `isNullAt(index)` check.
This PR avoids generating `isNullAt(index)` if `nullable` is false in the 
`DataType`.

Example
```
    val nonNullArray = Literal.create(Seq(1), ArrayType(IntegerType, false))
    checkEvaluation(GetArrayItem(nonNullArray, Literal(0)), 1)
```

Before this PR
```
/* 034 */   public java.lang.Object apply(java.lang.Object _i) {
/* 035 */     InternalRow i = (InternalRow) _i;
/* 036 */
/* 037 */
/* 038 */
/* 039 */     boolean isNull = true;
/* 040 */     int value = -1;
/* 041 */
/* 042 */
/* 043 */
/* 044 */     isNull = false; // resultCode could change nullability.
/* 045 */
/* 046 */     final int index = (int) 0;
/* 047 */     if (index >= ((ArrayData) references[0]).numElements() || index < 
0 || ((ArrayData) references[0]).isNullAt(index)) {
/* 048 */       isNull = true;
/* 049 */     } else {
/* 050 */       value = ((ArrayData) references[0]).getInt(index);
/* 051 */     }
/* 052 */     isNull_0 = isNull;
/* 053 */     value_0 = value;
/* 054 */
/* 055 */     // copy all the results into MutableRow
/* 056 */
/* 057 */     if (!isNull_0) {
/* 058 */       mutableRow.setInt(0, value_0);
/* 059 */     } else {
/* 060 */       mutableRow.setNullAt(0);
/* 061 */     }
/* 062 */
/* 063 */     return mutableRow;
/* 064 */   }
```

After this PR (Line 47 is changed)
```
/* 034 */   public java.lang.Object apply(java.lang.Object _i) {
/* 035 */     InternalRow i = (InternalRow) _i;
/* 036 */
/* 037 */
/* 038 */
/* 039 */     boolean isNull = true;
/* 040 */     int value = -1;
/* 041 */
/* 042 */
/* 043 */
/* 044 */     isNull = false; // resultCode could change nullability.
/* 045 */
/* 046 */     final int index = (int) 0;
/* 047 */     if (index >= ((ArrayData) references[0]).numElements() || index < 
0) {
/* 048 */       isNull = true;
/* 049 */     } else {
/* 050 */       value = ((ArrayData) references[0]).getInt(index);
/* 051 */     }
/* 052 */     isNull_0 = isNull;
/* 053 */     value_0 = value;
/* 054 */
/* 055 */     // copy all the results into MutableRow
/* 056 */
/* 057 */     if (!isNull_0) {
/* 058 */       mutableRow.setInt(0, value_0);
/* 059 */     } else {
/* 060 */       mutableRow.setNullAt(0);
/* 061 */     }
/* 062 */
/* 063 */     return mutableRow;
/* 064 */   }
```

## How was this patch tested?

Added test cases to `ComplexTypeSuite`.

Author: Kazuaki Ishizaki <[email protected]>

Closes #19598 from kiszk/SPARK-22378.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3bba8621
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3bba8621
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3bba8621

Branch: refs/heads/master
Commit: 3bba8621cf0a97f5c3134c9a160b1c8c5e97ba97
Parents: 6c66266
Author: Kazuaki Ishizaki <[email protected]>
Authored: Sat Nov 4 22:57:12 2017 -0700
Committer: gatorsmile <[email protected]>
Committed: Sat Nov 4 22:57:12 2017 -0700

----------------------------------------------------------------------
 .../expressions/complexTypeExtractors.scala     | 28 ++++++++++++++++----
 .../catalyst/expressions/ComplexTypeSuite.scala | 11 ++++++--
 2 files changed, 32 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3bba8621/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index ef88cfb..7e53ca3 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -186,6 +186,16 @@ case class GetArrayStructFields(
       val values = ctx.freshName("values")
       val j = ctx.freshName("j")
       val row = ctx.freshName("row")
+      val nullSafeEval = if (field.nullable) {
+        s"""
+         if ($row.isNullAt($ordinal)) {
+           $values[$j] = null;
+         } else
+        """
+      } else {
+        ""
+      }
+
       s"""
         final int $n = $eval.numElements();
         final Object[] $values = new Object[$n];
@@ -194,9 +204,7 @@ case class GetArrayStructFields(
             $values[$j] = null;
           } else {
             final InternalRow $row = $eval.getStruct($j, $numFields);
-            if ($row.isNullAt($ordinal)) {
-              $values[$j] = null;
-            } else {
+            $nullSafeEval {
               $values[$j] = ${ctx.getValue(row, field.dataType, 
ordinal.toString)};
             }
           }
@@ -242,9 +250,14 @@ case class GetArrayItem(child: Expression, ordinal: 
Expression)
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
       val index = ctx.freshName("index")
+      val nullCheck = if (child.dataType.asInstanceOf[ArrayType].containsNull) 
{
+        s" || $eval1.isNullAt($index)"
+      } else {
+        ""
+      }
       s"""
         final int $index = (int) $eval2;
-        if ($index >= $eval1.numElements() || $index < 0 || 
$eval1.isNullAt($index)) {
+        if ($index >= $eval1.numElements() || $index < 0$nullCheck) {
           ${ev.isNull} = true;
         } else {
           ${ev.value} = ${ctx.getValue(eval1, dataType, index)};
@@ -309,6 +322,11 @@ case class GetMapValue(child: Expression, key: Expression)
     val found = ctx.freshName("found")
     val key = ctx.freshName("key")
     val values = ctx.freshName("values")
+    val nullCheck = if 
(child.dataType.asInstanceOf[MapType].valueContainsNull) {
+      s" || $values.isNullAt($index)"
+    } else {
+      ""
+    }
     nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
       s"""
         final int $length = $eval1.numElements();
@@ -326,7 +344,7 @@ case class GetMapValue(child: Expression, key: Expression)
           }
         }
 
-        if (!$found || $values.isNullAt($index)) {
+        if (!$found$nullCheck) {
           ${ev.isNull} = true;
         } else {
           ${ev.value} = ${ctx.getValue(values, dataType, index)};

http://git-wip-us.apache.org/repos/asf/spark/blob/3bba8621/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 5f8a8f4..b0eaad1 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -51,6 +51,9 @@ class ComplexTypeSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     checkEvaluation(GetArrayItem(array, nullInt), null)
     checkEvaluation(GetArrayItem(nullArray, nullInt), null)
 
+    val nonNullArray = Literal.create(Seq(1), ArrayType(IntegerType, false))
+    checkEvaluation(GetArrayItem(nonNullArray, Literal(0)), 1)
+
     val nestedArray = Literal.create(Seq(Seq(1)), 
ArrayType(ArrayType(IntegerType)))
     checkEvaluation(GetArrayItem(nestedArray, Literal(0)), Seq(1))
   }
@@ -66,6 +69,9 @@ class ComplexTypeSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     checkEvaluation(GetMapValue(nullMap, nullString), null)
     checkEvaluation(GetMapValue(map, nullString), null)
 
+    val nonNullMap = Literal.create(Map("a" -> 1), MapType(StringType, 
IntegerType, false))
+    checkEvaluation(GetMapValue(nonNullMap, Literal("a")), 1)
+
     val nestedMap = Literal.create(Map("a" -> Map("b" -> "c")), 
MapType(StringType, typeM))
     checkEvaluation(GetMapValue(nestedMap, Literal("a")), Map("b" -> "c"))
   }
@@ -101,9 +107,10 @@ class ComplexTypeSuite extends SparkFunSuite with 
ExpressionEvalHelper {
   }
 
   test("GetArrayStructFields") {
-    val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
+    val typeAS = ArrayType(StructType(StructField("a", IntegerType, false) :: 
Nil))
+    val typeNullAS = ArrayType(StructType(StructField("a", IntegerType) :: 
Nil))
     val arrayStruct = Literal.create(Seq(create_row(1)), typeAS)
-    val nullArrayStruct = Literal.create(null, typeAS)
+    val nullArrayStruct = Literal.create(null, typeNullAS)
 
     def getArrayStructFields(expr: Expression, fieldName: String): 
GetArrayStructFields = {
       expr.dataType match {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to