Repository: spark
Updated Branches:
  refs/heads/master 5a8b978fa -> c1da4d421


[SPARK-13093] [SQL] improve null check in nullSafeCodeGen for unary, binary and 
ternary expressions

The current implementation is sub-optimal:

* If an expression is always nullable, e.g. `Unhex`, we can still remove the null 
check for its children if they are not nullable.
* If an expression has some non-nullable children, we can still remove the null 
check for these children and keep the null check for the others.

This PR improves this by making the null check elimination more fine-grained.

Author: Wenchen Fan <[email protected]>

Closes #10987 from cloud-fan/null-check.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c1da4d42
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c1da4d42
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c1da4d42

Branch: refs/heads/master
Commit: c1da4d421ab78772ffa52ad46e5bdfb4e5268f47
Parents: 5a8b978
Author: Wenchen Fan <[email protected]>
Authored: Sun Jan 31 22:43:03 2016 -0800
Committer: Davies Liu <[email protected]>
Committed: Sun Jan 31 22:43:03 2016 -0800

----------------------------------------------------------------------
 .../sql/catalyst/expressions/Expression.scala   | 104 ++++++++++---------
 .../expressions/codegen/CodeGenerator.scala     |  32 ++++--
 .../spark/sql/catalyst/expressions/misc.scala   |  16 +--
 3 files changed, 85 insertions(+), 67 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c1da4d42/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index db17ba7..353fb92 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -320,7 +320,7 @@ abstract class UnaryExpression extends Expression {
 
   /**
    * Called by unary expressions to generate a code block that returns null if 
its parent returns
-   * null, and if not not null, use `f` to generate the expression.
+   * null, and if not null, use `f` to generate the expression.
    *
    * As an example, the following does a boolean inversion (i.e. NOT).
    * {{{
@@ -340,7 +340,7 @@ abstract class UnaryExpression extends Expression {
 
   /**
    * Called by unary expressions to generate a code block that returns null if 
its parent returns
-   * null, and if not not null, use `f` to generate the expression.
+   * null, and if not null, use `f` to generate the expression.
    *
    * @param f function that accepts the non-null evaluation result name of 
child and returns Java
    *          code to compute the output.
@@ -349,20 +349,23 @@ abstract class UnaryExpression extends Expression {
       ctx: CodegenContext,
       ev: ExprCode,
       f: String => String): String = {
-    val eval = child.gen(ctx)
+    val childGen = child.gen(ctx)
+    val resultCode = f(childGen.value)
+
     if (nullable) {
-      eval.code + s"""
-        boolean ${ev.isNull} = ${eval.isNull};
+      val nullSafeEval = ctx.nullSafeExec(child.nullable, 
childGen.isNull)(resultCode)
+      s"""
+        ${childGen.code}
+        boolean ${ev.isNull} = ${childGen.isNull};
         ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
-        if (!${eval.isNull}) {
-          ${f(eval.value)}
-        }
+        $nullSafeEval
       """
     } else {
       ev.isNull = "false"
-      eval.code + s"""
+      s"""
+        ${childGen.code}
         ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
-        ${f(eval.value)}
+        $resultCode
       """
     }
   }
@@ -440,29 +443,31 @@ abstract class BinaryExpression extends Expression {
       ctx: CodegenContext,
       ev: ExprCode,
       f: (String, String) => String): String = {
-    val eval1 = left.gen(ctx)
-    val eval2 = right.gen(ctx)
-    val resultCode = f(eval1.value, eval2.value)
+    val leftGen = left.gen(ctx)
+    val rightGen = right.gen(ctx)
+    val resultCode = f(leftGen.value, rightGen.value)
+
     if (nullable) {
+      val nullSafeEval =
+        leftGen.code + ctx.nullSafeExec(left.nullable, leftGen.isNull) {
+          rightGen.code + ctx.nullSafeExec(right.nullable, rightGen.isNull) {
+            s"""
+              ${ev.isNull} = false; // resultCode could change nullability.
+              $resultCode
+            """
+          }
+      }
+
       s"""
-        ${eval1.code}
-        boolean ${ev.isNull} = ${eval1.isNull};
+        boolean ${ev.isNull} = true;
         ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
-        if (!${ev.isNull}) {
-          ${eval2.code}
-          if (!${eval2.isNull}) {
-            $resultCode
-          } else {
-            ${ev.isNull} = true;
-          }
-        }
+        $nullSafeEval
       """
-
     } else {
       ev.isNull = "false"
       s"""
-        ${eval1.code}
-        ${eval2.code}
+        ${leftGen.code}
+        ${rightGen.code}
         ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
         $resultCode
       """
@@ -527,7 +532,7 @@ abstract class TernaryExpression extends Expression {
 
   /**
    * Default behavior of evaluation according to the default nullability of 
TernaryExpression.
-   * If subclass of BinaryExpression override nullable, probably should also 
override this.
+   * If subclass of TernaryExpression override nullable, probably should also 
override this.
    */
   override def eval(input: InternalRow): Any = {
     val exprs = children
@@ -553,11 +558,11 @@ abstract class TernaryExpression extends Expression {
     sys.error(s"BinaryExpressions must override either eval or nullSafeEval")
 
   /**
-   * Short hand for generating binary evaluation code.
+   * Short hand for generating ternary evaluation code.
    * If either of the sub-expressions is null, the result of this computation
    * is assumed to be null.
    *
-   * @param f accepts two variable names and returns Java code to compute the 
output.
+   * @param f accepts three variable names and returns Java code to compute 
the output.
    */
   protected def defineCodeGen(
     ctx: CodegenContext,
@@ -569,41 +574,46 @@ abstract class TernaryExpression extends Expression {
   }
 
   /**
-   * Short hand for generating binary evaluation code.
+   * Short hand for generating ternary evaluation code.
    * If either of the sub-expressions is null, the result of this computation
    * is assumed to be null.
    *
-   * @param f function that accepts the 2 non-null evaluation result names of 
children
+   * @param f function that accepts the 3 non-null evaluation result names of 
children
    *          and returns Java code to compute the output.
    */
   protected def nullSafeCodeGen(
     ctx: CodegenContext,
     ev: ExprCode,
     f: (String, String, String) => String): String = {
-    val evals = children.map(_.gen(ctx))
-    val resultCode = f(evals(0).value, evals(1).value, evals(2).value)
+    val leftGen = children(0).gen(ctx)
+    val midGen = children(1).gen(ctx)
+    val rightGen = children(2).gen(ctx)
+    val resultCode = f(leftGen.value, midGen.value, rightGen.value)
+
     if (nullable) {
+      val nullSafeEval =
+        leftGen.code + ctx.nullSafeExec(children(0).nullable, leftGen.isNull) {
+          midGen.code + ctx.nullSafeExec(children(1).nullable, midGen.isNull) {
+            rightGen.code + ctx.nullSafeExec(children(2).nullable, 
rightGen.isNull) {
+              s"""
+                ${ev.isNull} = false; // resultCode could change nullability.
+                $resultCode
+              """
+            }
+          }
+      }
+
       s"""
-        ${evals(0).code}
         boolean ${ev.isNull} = true;
         ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
-        if (!${evals(0).isNull}) {
-          ${evals(1).code}
-          if (!${evals(1).isNull}) {
-            ${evals(2).code}
-            if (!${evals(2).isNull}) {
-              ${ev.isNull} = false;  // resultCode could change nullability
-              $resultCode
-            }
-          }
-        }
+        $nullSafeEval
       """
     } else {
       ev.isNull = "false"
       s"""
-        ${evals(0).code}
-        ${evals(1).code}
-        ${evals(2).code}
+        ${leftGen.code}
+        ${midGen.code}
+        ${rightGen.code}
         ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
         $resultCode
       """

http://git-wip-us.apache.org/repos/asf/spark/blob/c1da4d42/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 21f9198..a30aba1 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -402,18 +402,38 @@ class CodegenContext {
   }
 
   /**
-    * Generates code for greater of two expressions.
-    *
-    * @param dataType data type of the expressions
-    * @param c1 name of the variable of expression 1's output
-    * @param c2 name of the variable of expression 2's output
-    */
+   * Generates code for greater of two expressions.
+   *
+   * @param dataType data type of the expressions
+   * @param c1 name of the variable of expression 1's output
+   * @param c2 name of the variable of expression 2's output
+   */
   def genGreater(dataType: DataType, c1: String, c2: String): String = 
javaType(dataType) match {
     case JAVA_BYTE | JAVA_SHORT | JAVA_INT | JAVA_LONG => s"$c1 > $c2"
     case _ => s"(${genComp(dataType, c1, c2)}) > 0"
   }
 
   /**
+   * Generates code to do null safe execution, i.e. only execute the code when 
the input is not
+   * null by adding null check if necessary.
+   *
+   * @param nullable used to decide whether we should add null check or not.
+   * @param isNull the code to check if the input is null.
+   * @param execute the code that should only be executed when the input is 
not null.
+   */
+  def nullSafeExec(nullable: Boolean, isNull: String)(execute: String): String 
= {
+    if (nullable) {
+      s"""
+        if (!$isNull) {
+          $execute
+        }
+      """
+    } else {
+      "\n" + execute
+    }
+  }
+
+  /**
    * List of java data types that have special accessors and setters in 
[[InternalRow]].
    */
   val primitiveTypes =

http://git-wip-us.apache.org/repos/asf/spark/blob/c1da4d42/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 8480c3f..36e1fa1 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -327,7 +327,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: 
Int) extends Expression
     ev.isNull = "false"
     val childrenHash = children.map { child =>
       val childGen = child.gen(ctx)
-      childGen.code + generateNullCheck(child.nullable, childGen.isNull) {
+      childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) {
         computeHash(childGen.value, child.dataType, ev.value, ctx)
       }
     }.mkString("\n")
@@ -338,18 +338,6 @@ case class Murmur3Hash(children: Seq[Expression], seed: 
Int) extends Expression
     """
   }
 
-  private def generateNullCheck(nullable: Boolean, isNull: String)(execution: 
String): String = {
-    if (nullable) {
-      s"""
-        if (!$isNull) {
-          $execution
-        }
-      """
-    } else {
-      "\n" + execution
-    }
-  }
-
   private def nullSafeElementHash(
       input: String,
       index: String,
@@ -359,7 +347,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: 
Int) extends Expression
       ctx: CodegenContext): String = {
     val element = ctx.freshName("element")
 
-    generateNullCheck(nullable, s"$input.isNullAt($index)") {
+    ctx.nullSafeExec(nullable, s"$input.isNullAt($index)") {
       s"""
         final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, 
elementType, index)};
         ${computeHash(element, elementType, result, ctx)}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to