Repository: spark
Updated Branches:
  refs/heads/master e1d72f2c0 -> c5583fdcd


[SPARK-23466][SQL] Remove redundant null checks in generated Java code by GenerateUnsafeProjection

## What changes were proposed in this pull request?

This PR addresses one of the TODOs in `GenerateUnsafeProjection`, "if the nullability of field is correct, we can use it to save null check", to simplify the generated code.
When `nullable = false` in the `DataType`, `GenerateUnsafeProjection` now omits the null checks from the generated Java code.
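
As an illustration, here is a minimal sketch of how the new behavior can be exercised, mirroring the test added in `GenerateUnsafeProjectionSuite` (the variable names are only for this example):

```scala
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.types.{ArrayType, StringType}

// An array column that is declared non-nullable and whose elements are
// declared non-nullable as well (containsNull = false).
val exprs = BoundReference(0, ArrayType(StringType, containsNull = false), nullable = false) :: Nil

// With this change, the generated Java code writes each element directly
// instead of guarding it with "if (tmpInput.isNullAt(i)) { ... } else { ... }".
val projection = GenerateUnsafeProjection.generate(exprs)
```

Applying the projection to a row that never returns null (such as the `AlwaysNonNull` helper added to the test suite below) produces an `UnsafeRow` whose array has no null elements.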

## How was this patch tested?

Added new test cases to `GenerateUnsafeProjectionSuite`.

Closes #20637 from kiszk/SPARK-23466.

Authored-by: Kazuaki Ishizaki <[email protected]>
Signed-off-by: Takuya UESHIN <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c5583fdc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c5583fdc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c5583fdc

Branch: refs/heads/master
Commit: c5583fdcd2289559ad98371475eb7288ced9b148
Parents: e1d72f2
Author: Kazuaki Ishizaki <[email protected]>
Authored: Sat Sep 1 12:19:19 2018 +0900
Committer: Takuya UESHIN <[email protected]>
Committed: Sat Sep 1 12:19:19 2018 +0900

----------------------------------------------------------------------
 .../codegen/GenerateUnsafeProjection.scala      | 77 ++++++++++++--------
 .../expressions/JsonExpressionsSuite.scala      |  2 +-
 .../codegen/GenerateUnsafeProjectionSuite.scala | 71 +++++++++++++++++-
 3 files changed, 117 insertions(+), 33 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c5583fdc/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 998a675..0ecd0de 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -32,6 +32,8 @@ import org.apache.spark.sql.types._
  */
 object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] {
 
+  case class Schema(dataType: DataType, nullable: Boolean)
+
   /** Returns true iff we support this data type. */
   def canSupport(dataType: DataType): Boolean = UserDefinedType.sqlType(dataType) match {
     case NullType => true
@@ -43,19 +45,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     case _ => false
   }
 
-  // TODO: if the nullability of field is correct, we can use it to save null check.
   private def writeStructToBuffer(
       ctx: CodegenContext,
       input: String,
       index: String,
-      fieldTypes: Seq[DataType],
+      schemas: Seq[Schema],
       rowWriter: String): String = {
     // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
     val tmpInput = ctx.freshName("tmpInput")
-    val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
-      ExprCode(
-        JavaCode.isNullExpression(s"$tmpInput.isNullAt($i)"),
-        JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt))
+    val fieldEvals = schemas.zipWithIndex.map { case (Schema(dt, nullable), i) =>
+      val isNull = if (nullable) {
+        JavaCode.isNullExpression(s"$tmpInput.isNullAt($i)")
+      } else {
+        FalseLiteral
+      }
+      ExprCode(isNull, JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt))
     }
 
     val rowWriterClass = classOf[UnsafeRowWriter].getName
@@ -70,7 +74,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
       |  // Remember the current cursor so that we can calculate how many bytes are
        |  // written later.
        |  final int $previousCursor = $rowWriter.cursor();
-       |  ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, structRowWriter)}
+       |  ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, schemas, structRowWriter)}
       |  $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor);
        |}
      """.stripMargin
@@ -80,7 +84,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
       ctx: CodegenContext,
       row: String,
       inputs: Seq[ExprCode],
-      inputTypes: Seq[DataType],
+      schemas: Seq[Schema],
       rowWriter: String,
       isTopLevel: Boolean = false): String = {
     val resetWriter = if (isTopLevel) {
@@ -98,8 +102,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
       s"$rowWriter.resetRowWriter();"
     }
 
-    val writeFields = inputs.zip(inputTypes).zipWithIndex.map {
-      case ((input, dataType), index) =>
+    val writeFields = inputs.zip(schemas).zipWithIndex.map {
+      case ((input, Schema(dataType, nullable)), index) =>
         val dt = UserDefinedType.sqlType(dataType)
 
         val setNull = dt match {
@@ -110,7 +114,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
         }
 
         val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter)
-        if (input.isNull == FalseLiteral) {
+        if (!nullable) {
           s"""
              |${input.code}
              |${writeField.trim}
@@ -143,11 +147,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
      """.stripMargin
   }
 
-  // TODO: if the nullability of array element is correct, we can use it to save null check.
   private def writeArrayToBuffer(
       ctx: CodegenContext,
       input: String,
       elementType: DataType,
+      containsNull: Boolean,
       rowWriter: String): String = {
     // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
     val tmpInput = ctx.freshName("tmpInput")
@@ -170,6 +174,18 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
 
     val element = CodeGenerator.getValue(tmpInput, et, index)
 
+    val elementAssignment = if (containsNull) {
+      s"""
+         |if ($tmpInput.isNullAt($index)) {
+         |  $arrayWriter.setNull${elementOrOffsetSize}Bytes($index);
+         |} else {
+         |  ${writeElement(ctx, element, index, et, arrayWriter)}
+         |}
+       """.stripMargin
+    } else {
+      writeElement(ctx, element, index, et, arrayWriter)
+    }
+
     s"""
        |final ArrayData $tmpInput = $input;
        |if ($tmpInput instanceof UnsafeArrayData) {
@@ -179,23 +195,19 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
        |  $arrayWriter.initialize($numElements);
        |
        |  for (int $index = 0; $index < $numElements; $index++) {
-       |    if ($tmpInput.isNullAt($index)) {
-       |      $arrayWriter.setNull${elementOrOffsetSize}Bytes($index);
-       |    } else {
-       |      ${writeElement(ctx, element, index, et, arrayWriter)}
-       |    }
+       |    $elementAssignment
        |  }
        |}
      """.stripMargin
   }
 
-  // TODO: if the nullability of value element is correct, we can use it to save null check.
   private def writeMapToBuffer(
       ctx: CodegenContext,
       input: String,
       index: String,
       keyType: DataType,
       valueType: DataType,
+      valueContainsNull: Boolean,
       rowWriter: String): String = {
     // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
     val tmpInput = ctx.freshName("tmpInput")
@@ -203,6 +215,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     val previousCursor = ctx.freshName("previousCursor")
 
     // Writes out unsafe map according to the format described in `UnsafeMapData`.
+    val keyArray = writeArrayToBuffer(
+      ctx, s"$tmpInput.keyArray()", keyType, false, rowWriter)
+    val valueArray = writeArrayToBuffer(
+      ctx, s"$tmpInput.valueArray()", valueType, valueContainsNull, rowWriter)
+
     s"""
        |final MapData $tmpInput = $input;
        |if ($tmpInput instanceof UnsafeMapData) {
@@ -219,7 +236,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
       |  // Remember the current cursor so that we can write numBytes of key array later.
        |  final int $tmpCursor = $rowWriter.cursor();
        |
-       |  ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, rowWriter)}
+       |  $keyArray
        |
        |  // Write the numBytes of key array into the first 8 bytes.
        |  Platform.putLong(
@@ -227,7 +244,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
        |    $tmpCursor - 8,
        |    $rowWriter.cursor() - $tmpCursor);
        |
-       |  ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, rowWriter)}
+       |  $valueArray
       |  $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor);
        |}
      """.stripMargin
@@ -240,20 +257,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
       dt: DataType,
       writer: String): String = dt match {
     case t: StructType =>
-      writeStructToBuffer(ctx, input, index, t.map(_.dataType), writer)
+      writeStructToBuffer(
+        ctx, input, index, t.map(e => Schema(e.dataType, e.nullable)), writer)
 
-    case ArrayType(et, _) =>
+    case ArrayType(et, en) =>
       val previousCursor = ctx.freshName("previousCursor")
       s"""
         |// Remember the current cursor so that we can calculate how many bytes are
          |// written later.
          |final int $previousCursor = $writer.cursor();
-         |${writeArrayToBuffer(ctx, input, et, writer)}
+         |${writeArrayToBuffer(ctx, input, et, en, writer)}
          |$writer.setOffsetAndSizeFromPreviousCursor($index, $previousCursor);
        """.stripMargin
 
-    case MapType(kt, vt, _) =>
-      writeMapToBuffer(ctx, input, index, kt, vt, writer)
+    case MapType(kt, vt, vn) =>
+      writeMapToBuffer(ctx, input, index, kt, vt, vn, writer)
 
     case DecimalType.Fixed(precision, scale) =>
       s"$writer.write($index, $input, $precision, $scale);"
@@ -268,12 +286,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
       expressions: Seq[Expression],
       useSubexprElimination: Boolean = false): ExprCode = {
     val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination)
-    val exprTypes = expressions.map(_.dataType)
+    val exprSchemas = expressions.map(e => Schema(e.dataType, e.nullable))
 
-    val numVarLenFields = exprTypes.count {
-      case dt if UnsafeRow.isFixedLength(dt) => false
+    val numVarLenFields = exprSchemas.count {
+      case Schema(dt, _) => !UnsafeRow.isFixedLength(dt)
       // TODO: consider large decimal and interval type
-      case _ => true
     }
 
     val rowWriterClass = classOf[UnsafeRowWriter].getName
@@ -284,7 +301,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     val evalSubexpr = ctx.subexprFunctions.mkString("\n")
 
     val writeExpressions = writeExpressionsToBuffer(
-      ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true)
+      ctx, ctx.INPUT_ROW, exprEvals, exprSchemas, rowWriter, isTopLevel = true)
 
     val code =
       code"""

http://git-wip-us.apache.org/repos/asf/spark/blob/c5583fdc/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
index 04f1c8c..0e9c8ab 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
@@ -694,7 +694,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
           |""".stripMargin
         val jsonSchema = new StructType()
           .add("a", LongType, nullable = false)
-          .add("b", StringType, nullable = false)
+          .add("b", StringType, nullable = !forceJsonNullableSchema)
           .add("c", StringType, nullable = false)
         val output = InternalRow(1L, null, UTF8String.fromString("foo"))
         val expr = JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId)

http://git-wip-us.apache.org/repos/asf/spark/blob/c5583fdc/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala
index e9d21f8..01aa357 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.expressions.codegen
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.BoundReference
-import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
-import org.apache.spark.sql.types.{DataType, Decimal, StringType, StructType}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, MapData}
+import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
 class GenerateUnsafeProjectionSuite extends SparkFunSuite {
@@ -33,6 +33,41 @@ class GenerateUnsafeProjectionSuite extends SparkFunSuite {
     assert(!result.isNullAt(0))
     assert(result.getStruct(0, 1).isNullAt(0))
   }
+
+  test("Test unsafe projection for array/map/struct") {
+    val dataType1 = ArrayType(StringType, false)
+    val exprs1 = BoundReference(0, dataType1, nullable = false) :: Nil
+    val projection1 = GenerateUnsafeProjection.generate(exprs1)
+    val result1 = projection1.apply(AlwaysNonNull)
+    assert(!result1.isNullAt(0))
+    assert(!result1.getArray(0).isNullAt(0))
+    assert(!result1.getArray(0).isNullAt(1))
+    assert(!result1.getArray(0).isNullAt(2))
+
+    val dataType2 = MapType(StringType, StringType, false)
+    val exprs2 = BoundReference(0, dataType2, nullable = false) :: Nil
+    val projection2 = GenerateUnsafeProjection.generate(exprs2)
+    val result2 = projection2.apply(AlwaysNonNull)
+    assert(!result2.isNullAt(0))
+    assert(!result2.getMap(0).keyArray.isNullAt(0))
+    assert(!result2.getMap(0).keyArray.isNullAt(1))
+    assert(!result2.getMap(0).keyArray.isNullAt(2))
+    assert(!result2.getMap(0).valueArray.isNullAt(0))
+    assert(!result2.getMap(0).valueArray.isNullAt(1))
+    assert(!result2.getMap(0).valueArray.isNullAt(2))
+
+    val dataType3 = (new StructType)
+      .add("a", StringType, nullable = false)
+      .add("b", StringType, nullable = false)
+      .add("c", StringType, nullable = false)
+    val exprs3 = BoundReference(0, dataType3, nullable = false) :: Nil
+    val projection3 = GenerateUnsafeProjection.generate(exprs3)
+    val result3 = projection3.apply(InternalRow(AlwaysNonNull))
+    assert(!result3.isNullAt(0))
+    assert(!result3.getStruct(0, 1).isNullAt(0))
+    assert(!result3.getStruct(0, 2).isNullAt(0))
+    assert(!result3.getStruct(0, 3).isNullAt(0))
+  }
 }
 
 object AlwaysNull extends InternalRow {
@@ -59,3 +94,35 @@ object AlwaysNull extends InternalRow {
   override def get(ordinal: Int, dataType: DataType): AnyRef = notSupported
   private def notSupported: Nothing = throw new UnsupportedOperationException
 }
+
+object AlwaysNonNull extends InternalRow {
+  private def stringToUTF8Array(stringArray: Array[String]): ArrayData = {
+    val utf8Array = stringArray.map(s => UTF8String.fromString(s)).toArray
+    ArrayData.toArrayData(utf8Array)
+  }
+  override def numFields: Int = 1
+  override def setNullAt(i: Int): Unit = {}
+  override def copy(): InternalRow = this
+  override def anyNull: Boolean = notSupported
+  override def isNullAt(ordinal: Int): Boolean = notSupported
+  override def update(i: Int, value: Any): Unit = notSupported
+  override def getBoolean(ordinal: Int): Boolean = notSupported
+  override def getByte(ordinal: Int): Byte = notSupported
+  override def getShort(ordinal: Int): Short = notSupported
+  override def getInt(ordinal: Int): Int = notSupported
+  override def getLong(ordinal: Int): Long = notSupported
+  override def getFloat(ordinal: Int): Float = notSupported
+  override def getDouble(ordinal: Int): Double = notSupported
+  override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = notSupported
+  override def getUTF8String(ordinal: Int): UTF8String = UTF8String.fromString("test")
+  override def getBinary(ordinal: Int): Array[Byte] = notSupported
+  override def getInterval(ordinal: Int): CalendarInterval = notSupported
+  override def getStruct(ordinal: Int, numFields: Int): InternalRow = notSupported
+  override def getArray(ordinal: Int): ArrayData = stringToUTF8Array(Array("1", "2", "3"))
+  val keyArray = stringToUTF8Array(Array("1", "2", "3"))
+  val valueArray = stringToUTF8Array(Array("a", "b", "c"))
+  override def getMap(ordinal: Int): MapData = new ArrayBasedMapData(keyArray, valueArray)
+  override def get(ordinal: Int, dataType: DataType): AnyRef = notSupported
+  private def notSupported: Nothing = throw new UnsupportedOperationException
+
+}


