Repository: spark
Updated Branches:
  refs/heads/master 4117786a8 -> c7d014861


[SPARK-22600][SQL] Fix 64kb limit for deeply nested expressions under wholestage codegen

## What changes were proposed in this pull request?

SPARK-22543 fixed the 64kb compile error for deeply nested expressions under
non-wholestage codegen. This PR extends that fix to wholestage codegen.

This patch adds utility methods that extract the parameters an expression needs
when its generated code is split into a separate function.

The utility methods live in object `ExpressionCodegen` in the `codegen` package.
The main entry point is `getExpressionInputParams`, which returns all parameters
needed to evaluate the given expression in a split function.

These utility methods can also be used to split expressions in general; that is
left as a TODO for later.

## How was this patch tested?

Added test.

Author: Liang-Chi Hsieh <[email protected]>

Closes #19813 from viirya/reduce-expr-code-for-wholestage.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c7d01486
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c7d01486
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c7d01486

Branch: refs/heads/master
Commit: c7d0148615c921dca782ee3785b5d0cd59e42262
Parents: 4117786
Author: Liang-Chi Hsieh <[email protected]>
Authored: Wed Dec 13 10:40:05 2017 +0800
Committer: Wenchen Fan <[email protected]>
Committed: Wed Dec 13 10:40:05 2017 +0800

----------------------------------------------------------------------
 .../sql/catalyst/expressions/Expression.scala   |  37 ++-
 .../expressions/codegen/CodeGenerator.scala     |  44 ++-
 .../expressions/codegen/ExpressionCodegen.scala | 269 +++++++++++++++++++
 .../codegen/ExpressionCodegenSuite.scala        | 220 +++++++++++++++
 .../spark/sql/execution/ColumnarBatchScan.scala |   5 +-
 .../sql/execution/WholeStageCodegenSuite.scala  |  23 +-
 6 files changed, 585 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c7d01486/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 743782a..329ea5d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -105,6 +105,12 @@ abstract class Expression extends TreeNode[Expression] {
       val isNull = ctx.freshName("isNull")
       val value = ctx.freshName("value")
       val eval = doGenCode(ctx, ExprCode("", isNull, value))
+      eval.isNull = if (this.nullable) eval.isNull else "false"
+
+      // Records current input row and variables of this expression.
+      eval.inputRow = ctx.INPUT_ROW
+      eval.inputVars = findInputVars(ctx)
+
       reduceCodeSize(ctx, eval)
       if (eval.code.nonEmpty) {
         // Add `this` in the comment.
@@ -115,9 +121,29 @@ abstract class Expression extends TreeNode[Expression] {
     }
   }
 
+  /**
+   * Collects the non-null `ctx.currentVars` entries that this tree's BoundReferences point to.
+   */
+  private def findInputVars(ctx: CodegenContext): Seq[ExprInputVar] = {
+    if (ctx.currentVars != null) {
+      this.collect {
+        case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null =>
+          ExprInputVar(exprCode = ctx.currentVars(ordinal),
+            dataType = b.dataType, nullable = b.nullable)
+      }
+    } else {
+      Seq.empty
+    }
+  }
+
+  /**
+   * To avoid the 64KB method size limit, moves the generated code into a separate function
+   * when it is longer than 1024 characters and its inputs can be passed as parameters.
+   */
   private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = {
-    // TODO: support whole stage codegen too
-    if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) {
+    lazy val funcParams = ExpressionCodegen.getExpressionInputParams(ctx, this)
+
+    if (eval.code.trim.length > 1024 && funcParams.isDefined) {
       val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") {
         val globalIsNull = ctx.freshName("globalIsNull")
         ctx.addMutableState(ctx.JAVA_BOOLEAN, globalIsNull)
@@ -132,9 +158,12 @@ abstract class Expression extends TreeNode[Expression] {
       val newValue = ctx.freshName("value")
 
       val funcName = ctx.freshName(nodeName)
+      val callParams = funcParams.map(_._1.mkString(", ")).get
+      val declParams = funcParams.map(_._2.mkString(", ")).get
+
       val funcFullName = ctx.addNewFunction(funcName,
         s"""
-           |private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) {
+           |private $javaType $funcName($declParams) {
            |  ${eval.code.trim}
            |  $setIsNull
            |  return ${eval.value};
@@ -142,7 +171,7 @@ abstract class Expression extends TreeNode[Expression] {
            """.stripMargin)
 
       eval.value = newValue
-      eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
+      eval.code = s"$javaType $newValue = $funcFullName($callParams);"
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c7d01486/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 4b8b16f..257c3f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -55,8 +55,24 @@ import org.apache.spark.util.{ParentClassLoader, Utils}
  *                 to null.
  * @param value A term for a (possibly primitive) value of the result of the evaluation. Not
  *              valid if `isNull` is set to `true`.
+ * @param inputRow A term that holds the input row name when generating this code, or null.
+ * @param inputVars A list of [[ExprInputVar]] that holds input variables when generating this code.
  */
-case class ExprCode(var code: String, var isNull: String, var value: String)
+case class ExprCode(
+    var code: String,
+    var isNull: String,
+    var value: String,
+    var inputRow: String = null,
+    var inputVars: Seq[ExprInputVar] = Seq.empty)
+
+/**
+ * Represents an input variable [[ExprCode]] to an evaluation of an [[Expression]].
+ *
+ * @param exprCode The [[ExprCode]] that represents the evaluation result for the input variable.
+ * @param dataType The data type of the input variable.
+ * @param nullable Whether the input variable can be null or not.
+ */
+case class ExprInputVar(exprCode: ExprCode, dataType: DataType, nullable: Boolean)
 
 /**
  * State used for subexpression elimination.
@@ -1012,16 +1028,25 @@ class CodegenContext {
     commonExprs.foreach { e =>
       val expr = e.head
       val fnName = freshName("evalExpr")
-      val isNull = s"${fnName}IsNull"
+      val isNull = if (expr.nullable) {
+        s"${fnName}IsNull"
+      } else {
+        ""
+      }
       val value = s"${fnName}Value"
 
       // Generate the code for this expression tree and wrap it in a function.
       val eval = expr.genCode(this)
+      val assignIsNull = if (expr.nullable) {
+        s"$isNull = ${eval.isNull};"
+      } else {
+        ""
+      }
       val fn =
         s"""
            |private void $fnName(InternalRow $INPUT_ROW) {
            |  ${eval.code.trim}
-           |  $isNull = ${eval.isNull};
+           |  $assignIsNull
            |  $value = ${eval.value};
            |}
            """.stripMargin
@@ -1039,12 +1064,17 @@ class CodegenContext {
       //   2. Less code.
       // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with
       // at least two nodes) as the cost of doing it is expected to be low.
-      addMutableState(JAVA_BOOLEAN, isNull, s"$isNull = false;")
-      addMutableState(javaType(expr.dataType), value,
-        s"$value = ${defaultValue(expr.dataType)};")
+      if (expr.nullable) {
+        addMutableState(JAVA_BOOLEAN, isNull)
+      }
+      addMutableState(javaType(expr.dataType), value)
 
       subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);"
-      val state = SubExprEliminationState(isNull, value)
+      val state = if (expr.nullable) {
+        SubExprEliminationState(isNull, value)
+      } else {
+        SubExprEliminationState("false", value)
+      }
       e.foreach(subExprEliminationExprs.put(_, state))
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/c7d01486/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala
new file mode 100644
index 0000000..a2dda48
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala
@@ -0,0 +1,269 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.DataType
+
+/**
+ * Defines util methods used in expression code generation.
+ */
+object ExpressionCodegen {
+
+  /**
+   * Given an expression, returns all the parameters needed to evaluate it, so that the generated
+   * code of this expression can be split into a function.
+   * The 1st element of the returned tuple holds the arguments used to call the function.
+   * The 2nd element of the returned tuple holds the parameters used to declare the function.
+   *
+   * Returns `None` if the parameters would need more than 255 JVM method-descriptor slots.
+   *
+   * Params to include:
+   * 1. Evaluated columns referred by this, children or deferred expressions.
+   * 2. Rows referred by this, children or deferred expressions.
+   * 3. Eliminated subexpressions referred by children expressions.
+   */
+  def getExpressionInputParams(
+      ctx: CodegenContext,
+      expr: Expression): Option[(Seq[String], Seq[String])] = {
+    val subExprs = getSubExprInChildren(ctx, expr)
+    val subExprCodes = getSubExprCodes(ctx, subExprs)
+    val subVars = subExprs.zip(subExprCodes).map { case (subExpr, subExprCode) =>
+      ExprInputVar(subExprCode, subExpr.dataType, subExpr.nullable)
+    }
+    val paramsFromSubExprs = prepareFunctionParams(ctx, subVars)
+
+    val inputVars = getInputVarsForChildren(ctx, expr)
+    val paramsFromColumns = prepareFunctionParams(ctx, inputVars)
+
+    val inputRows = ctx.INPUT_ROW +: getInputRowsForChildren(ctx, expr)
+    val paramsFromRows = inputRows.distinct.filter(_ != null).map { row =>
+      (row, s"InternalRow $row")
+    }
+
+    val paramsLength = getParamLength(ctx, inputVars ++ subVars) + paramsFromRows.length
+    // Maximum allowed parameter number for Java's method descriptor.
+    if (paramsLength > 255) {
+      None
+    } else {
+      val allParams = (paramsFromRows ++ paramsFromColumns ++ paramsFromSubExprs).unzip
+      val callParams = allParams._1.distinct
+      val declParams = allParams._2.distinct
+      Some((callParams, declParams))
+    }
+  }
+
+  /**
+   * Returns the eliminated subexpressions in the children expressions.
+   */
+  def getSubExprInChildren(ctx: CodegenContext, expr: Expression): Seq[Expression] = {
+    expr.children.flatMap { child =>
+      child.collect {
+        case e if ctx.subExprEliminationExprs.contains(e) => e
+      }
+    }.distinct
+  }
+
+  /**
+   * A small helper function to return `ExprCode`s that represent subexpressions.
+   */
+  def getSubExprCodes(ctx: CodegenContext, subExprs: Seq[Expression]): Seq[ExprCode] = {
+    subExprs.map { subExpr =>
+      val state = ctx.subExprEliminationExprs(subExpr)
+      ExprCode(code = "", value = state.value, isNull = state.isNull)
+    }
+  }
+
+  /**
+   * Retrieves previous input rows referred by children and deferred expressions.
+   */
+  def getInputRowsForChildren(ctx: CodegenContext, expr: Expression): Seq[String] = {
+    expr.children.flatMap(getInputRows(ctx, _)).distinct
+  }
+
+  /**
+   * Given a child expression, retrieves previous input rows referred by it or deferred expressions
+   * which are needed to evaluate it.
+   */
+  def getInputRows(ctx: CodegenContext, child: Expression): Seq[String] = {
+    child.flatMap {
+      // An expression directly evaluates on current input row.
+      case BoundReference(ordinal, _, _) if ctx.currentVars == null ||
+          ctx.currentVars(ordinal) == null =>
+        Seq(ctx.INPUT_ROW)
+
+      // An expression which is not evaluated yet. Tracks down to find input rows.
+      case BoundReference(ordinal, _, _) if !isEvaluated(ctx.currentVars(ordinal)) =>
+        trackDownRow(ctx, ctx.currentVars(ordinal))
+
+      case _ => Seq.empty
+    }.distinct
+  }
+
+  /**
+   * Tracks down input rows referred by the generated code snippet.
+   */
+  def trackDownRow(ctx: CodegenContext, exprCode: ExprCode): Seq[String] = {
+    val exprCodes = mutable.Queue[ExprCode](exprCode)
+    val inputRows = mutable.ArrayBuffer.empty[String]
+
+    while (exprCodes.nonEmpty) {
+      val curExprCode = exprCodes.dequeue()
+      if (curExprCode.inputRow != null) {
+        inputRows += curExprCode.inputRow
+      }
+      curExprCode.inputVars.foreach { inputVar =>
+        if (!isEvaluated(inputVar.exprCode)) {
+          exprCodes.enqueue(inputVar.exprCode)
+        }
+      }
+    }
+    inputRows
+  }
+
+  /**
+   * Retrieves previously evaluated columns referred by children and deferred expressions.
+   * Returns the distinct [[ExprInputVar]]s that describe those evaluated columns.
+   */
+  def getInputVarsForChildren(
+      ctx: CodegenContext,
+      expr: Expression): Seq[ExprInputVar] = {
+    expr.children.flatMap(getInputVars(ctx, _)).distinct
+  }
+
+  /**
+   * Given a child expression, retrieves previously evaluated columns referred by it or
+   * deferred expressions which are needed to evaluate it.
+   */
+  def getInputVars(ctx: CodegenContext, child: Expression): Seq[ExprInputVar] = {
+    if (ctx.currentVars == null) {
+      return Seq.empty
+    }
+
+    child.flatMap {
+      // An evaluated variable.
+      case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null &&
+          isEvaluated(ctx.currentVars(ordinal)) =>
+        Seq(ExprInputVar(ctx.currentVars(ordinal), b.dataType, b.nullable))
+
+      // An input variable which is not evaluated yet. Tracks down to find any evaluated variables
+      // in the expression path.
+      // E.g., if this expression is "d = c + 1" and "c" is not evaluated. We need to track to
+      // "c = a + b" and see if "a" and "b" are evaluated. If they are, we need to return them so
+      // to include them into parameters, if not, we track down further.
+      case BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null =>
+        trackDownVar(ctx, ctx.currentVars(ordinal))
+
+      case _ => Seq.empty
+    }.distinct
+  }
+
+  /**
+   * Tracks down previously evaluated columns referred by the generated code snippet.
+   */
+  def trackDownVar(ctx: CodegenContext, exprCode: ExprCode): Seq[ExprInputVar] = {
+    val exprCodes = mutable.Queue[ExprCode](exprCode)
+    val inputVars = mutable.ArrayBuffer.empty[ExprInputVar]
+
+    while (exprCodes.nonEmpty) {
+      exprCodes.dequeue().inputVars.foreach { inputVar =>
+        if (isEvaluated(inputVar.exprCode)) {
+          inputVars += inputVar
+        } else {
+          exprCodes.enqueue(inputVar.exprCode)
+        }
+      }
+    }
+    inputVars
+  }
+
+  /**
+   * Upper bound on parameter slots for an input: value (2 if long/double) + 1 if nullable.
+   */
+  def calculateParamLength(ctx: CodegenContext, input: ExprInputVar): Int = {
+    (if (input.nullable) 1 else 0) + (ctx.javaType(input.dataType) match {
+      case ctx.JAVA_LONG | ctx.JAVA_DOUBLE => 2
+      case _ => 1
+    }) // Parenthesized: `match` binds looser than `+`, which would concatenate Int + String.
+  }
+
+  /**
+   * In Java, a method descriptor is valid only if it represents method parameters with a total
+   * length of 255 or less. `this` contributes one unit and a parameter of type long or double
+   * contributes two units.
+   */
+  def getParamLength(ctx: CodegenContext, inputs: Seq[ExprInputVar]): Int = {
+    // Initial value is 1 for `this`.
+    1 + inputs.map(calculateParamLength(ctx, _)).sum
+  }
+
+  /**
+   * Given the input variables of an expression, returns the strings of function parameters.
+   * The first is the variable name used to call the function, the second is the parameter
+   * declaration used in generated code. Literal values and literal isNull flags are skipped.
+   */
+  def prepareFunctionParams(
+      ctx: CodegenContext,
+      inputVars: Seq[ExprInputVar]): Seq[(String, String)] = {
+    inputVars.flatMap { inputVar =>
+      val params = mutable.ArrayBuffer.empty[(String, String)]
+      val ev = inputVar.exprCode
+
+      // Only include the expression value if it is not a literal.
+      if (!isLiteral(ev)) {
+        val argType = ctx.javaType(inputVar.dataType)
+        params += ((ev.value, s"$argType ${ev.value}"))
+      }
+
+      // If it is a nullable expression and `isNull` is not a literal.
+      if (inputVar.nullable && ev.isNull != "true" && ev.isNull != "false") {
+        params += ((ev.isNull, s"boolean ${ev.isNull}"))
+      }
+
+      params
+    }.distinct
+  }
+
+  /**
+   * Only applied to the `ExprCode` in `ctx.currentVars`.
+   * Returns true if this value is a literal.
+   */
+  def isLiteral(exprCode: ExprCode): Boolean = {
+    assert(exprCode.value.nonEmpty, "ExprCode.value can't be empty string.")
+
+    if (exprCode.value == "true" || exprCode.value == "false" || exprCode.value == "null") {
+      true
+    } else {
+      // The valid characters for the first character of a Java variable are [a-zA-Z_$].
+      exprCode.value.head match {
+        case v if v >= 'a' && v <= 'z' => false
+        case v if v >= 'A' && v <= 'Z' => false
+        case '_' | '$' => false
+        case _ => true
+      }
+    }
+  }
+
+  /**
+   * Only applied to the `ExprCode` in `ctx.currentVars`.
+   * The code is emptied after evaluation.
+   */
+  def isEvaluated(exprCode: ExprCode): Boolean = exprCode.code == ""
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/c7d01486/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala
new file mode 100644
index 0000000..39d58ca
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.IntegerType
+
+class ExpressionCodegenSuite extends SparkFunSuite {
+
+  test("Returns eliminated subexpressions for expression") {
+    val ctx = new CodegenContext()
+    val subExpr = Add(Literal(1), Literal(2))
+    val exprs = Seq(Add(subExpr, Literal(3)), Add(subExpr, Literal(4)))
+
+    ctx.generateExpressions(exprs, doSubexpressionElimination = true)
+    val subexpressions = ExpressionCodegen.getSubExprInChildren(ctx, exprs(0))
+    assert(subexpressions.length == 1 && subexpressions(0) == subExpr)
+  }
+
+  test("Gets parameters for subexpressions") {
+    val ctx = new CodegenContext()
+    val subExprs = Seq(
+      Add(Literal(1), AttributeReference("a", IntegerType, nullable = false)()), // non-nullable
+      Add(Literal(2), AttributeReference("b", IntegerType, nullable = true)()))  // nullable
+
+    ctx.subExprEliminationExprs.put(subExprs(0), SubExprEliminationState("false", "value1"))
+    ctx.subExprEliminationExprs.put(subExprs(1), SubExprEliminationState("isNull2", "value2"))
+
+    val subExprCodes = ExpressionCodegen.getSubExprCodes(ctx, subExprs)
+    val subVars = subExprs.zip(subExprCodes).map { case (expr, exprCode) =>
+      ExprInputVar(exprCode, expr.dataType, expr.nullable)
+    }
+    val params = ExpressionCodegen.prepareFunctionParams(ctx, subVars)
+    assert(params.length == 3)
+    assert(params(0) == Tuple2("value1", "int value1"))
+    assert(params(1) == Tuple2("value2", "int value2"))
+    assert(params(2) == Tuple2("isNull2", "boolean isNull2"))
+  }
+
+  test("Returns input variables for expression: current variables") {
+    val ctx = new CodegenContext()
+    val currentVars = Seq(
+      ExprCode("", isNull = "false", value = "value1"),             // evaluated
+      ExprCode("", isNull = "isNull2", value = "value2"),           // evaluated
+      ExprCode("fake code;", isNull = "isNull3", value = "value3")) // not evaluated
+    ctx.currentVars = currentVars
+    ctx.INPUT_ROW = null
+
+    val expr = If(Literal(false),
+      Add(BoundReference(0, IntegerType, nullable = false),
+          BoundReference(1, IntegerType, nullable = true)),
+        BoundReference(2, IntegerType, nullable = true))
+
+    val inputVars = ExpressionCodegen.getInputVarsForChildren(ctx, expr)
+    // Only two evaluated variables included.
+    assert(inputVars.length == 2)
+    assert(inputVars(0).dataType == IntegerType && inputVars(0).nullable == false)
+    assert(inputVars(1).dataType == IntegerType && inputVars(1).nullable == true)
+    assert(inputVars(0).exprCode == currentVars(0))
+    assert(inputVars(1).exprCode == currentVars(1))
+
+    val params = ExpressionCodegen.prepareFunctionParams(ctx, inputVars)
+    assert(params.length == 3)
+    assert(params(0) == Tuple2("value1", "int value1"))
+    assert(params(1) == Tuple2("value2", "int value2"))
+    assert(params(2) == Tuple2("isNull2", "boolean isNull2"))
+  }
+
+  test("Returns input variables for expression: deferred variables") {
+    val ctx = new CodegenContext()
+
+    // The referred column is not evaluated yet. But it depends on an evaluated column from
+    // other operator.
+    val currentVars = Seq(ExprCode("fake code;", isNull = "isNull1", value = "value1"))
+
+    // currentVars(0) depends on this evaluated column.
+    currentVars(0).inputVars = Seq(ExprInputVar(ExprCode("", isNull = "isNull2", value = "value2"),
+      dataType = IntegerType, nullable = true))
+    ctx.currentVars = currentVars
+    ctx.INPUT_ROW = null
+
+    val expr = Add(Literal(1), BoundReference(0, IntegerType, nullable = false))
+    val inputVars = ExpressionCodegen.getInputVarsForChildren(ctx, expr)
+    assert(inputVars.length == 1)
+    assert(inputVars(0).dataType == IntegerType && inputVars(0).nullable == true)
+
+    val params = ExpressionCodegen.prepareFunctionParams(ctx, inputVars)
+    assert(params.length == 2)
+    assert(params(0) == Tuple2("value2", "int value2"))
+    assert(params(1) == Tuple2("isNull2", "boolean isNull2"))
+  }
+
+  test("Returns input rows for expression") {
+    val ctx = new CodegenContext()
+    ctx.currentVars = null
+    ctx.INPUT_ROW = "i"
+
+    val expr = Add(BoundReference(0, IntegerType, nullable = false),
+      BoundReference(1, IntegerType, nullable = true))
+    val inputRows = ExpressionCodegen.getInputRowsForChildren(ctx, expr)
+    assert(inputRows.length == 1)
+    assert(inputRows(0) == "i")
+  }
+
+  test("Returns input rows for expression: deferred expression") {
+    val ctx = new CodegenContext()
+
+    // The referred column is not evaluated yet. But it depends on an input row from
+    // other operator.
+    val currentVars = Seq(ExprCode("fake code;", isNull = "isNull1", value = "value1"))
+    currentVars(0).inputRow = "inputadaptor_row1"
+    ctx.currentVars = currentVars
+    ctx.INPUT_ROW = null
+
+    val expr = Add(Literal(1), BoundReference(0, IntegerType, nullable = false))
+    val inputRows = ExpressionCodegen.getInputRowsForChildren(ctx, expr)
+    assert(inputRows.length == 1)
+    assert(inputRows(0) == "inputadaptor_row1")
+  }
+
+  test("Returns both input rows and variables for expression") {
+    val ctx = new CodegenContext()
+    // 5 input variables in currentVars:
+    //   1 evaluated variable (value1).
+    //   3 not evaluated variables.
+    //     value2 depends on an evaluated column from other operator.
+    //     value3 depends on an input row from other operator.
+    //     value4 depends on a not evaluated yet column from other operator.
+    //   1 null indicating to use input row "i".
+    val currentVars = Seq(
+      ExprCode("", isNull = "false", value = "value1"),
+      ExprCode("fake code;", isNull = "isNull2", value = "value2"),
+      ExprCode("fake code;", isNull = "isNull3", value = "value3"),
+      ExprCode("fake code;", isNull = "isNull4", value = "value4"),
+      null)
+    // value2 depends on this evaluated column.
+    currentVars(1).inputVars = Seq(ExprInputVar(ExprCode("", isNull = "isNull5", value = "value5"),
+      dataType = IntegerType, nullable = true))
+    // value3 depends on an input row "inputadaptor_row1".
+    currentVars(2).inputRow = "inputadaptor_row1"
+    // value4 depends on another not evaluated yet column.
+    currentVars(3).inputVars = Seq(ExprInputVar(ExprCode("fake code;",
+      isNull = "isNull6", value = "value6"), dataType = IntegerType, nullable = true))
+    ctx.currentVars = currentVars
+    ctx.INPUT_ROW = "i"
+
+    // expr: if (false) { value1 + value2 } else { (value3 + value4) + i[4] }
+    val expr = If(Literal(false),
+      Add(BoundReference(0, IntegerType, nullable = false),
+          BoundReference(1, IntegerType, nullable = true)),
+      Add(Add(BoundReference(2, IntegerType, nullable = true),
+              BoundReference(3, IntegerType, nullable = true)),
+          BoundReference(4, IntegerType, nullable = true))) // this is based on input row "i".
+
+    // input rows, in pre-order traversal order: "inputadaptor_row1", "i".
+    val inputRows = ExpressionCodegen.getInputRowsForChildren(ctx, expr)
+    assert(inputRows.length == 2)
+    assert(inputRows(0) == "inputadaptor_row1")
+    assert(inputRows(1) == "i")
+
+    // input variables: value1 and value5
+    val inputVars = ExpressionCodegen.getInputVarsForChildren(ctx, expr)
+    assert(inputVars.length == 2)
+
+    // value1's isNull is inlined as "false", so it is not included in the params.
+    val inputVarParams = ExpressionCodegen.prepareFunctionParams(ctx, inputVars)
+    assert(inputVarParams.length == 3)
+    assert(inputVarParams(0) == Tuple2("value1", "int value1"))
+    assert(inputVarParams(1) == Tuple2("value5", "int value5"))
+    assert(inputVarParams(2) == Tuple2("isNull5", "boolean isNull5"))
+  }
+
+  test("isLiteral: literals") {
+    val literals = Seq(
+      ExprCode("", "", "true"),
+      ExprCode("", "", "false"),
+      ExprCode("", "", "1"),
+      ExprCode("", "", "-1"),
+      ExprCode("", "", "1L"),
+      ExprCode("", "", "-1L"),
+      ExprCode("", "", "1.0f"),
+      ExprCode("", "", "-1.0f"),
+      ExprCode("", "", "0.1f"),
+      ExprCode("", "", "-0.1f"),
+      ExprCode("", "", """"string""""),
+      ExprCode("", "", "(byte)-1"),
+      ExprCode("", "", "(short)-1"),
+      ExprCode("", "", "null"))
+
+    literals.foreach(l => assert(ExpressionCodegen.isLiteral(l) == true))
+  }
+
+  test("isLiteral: non literals") {
+    val variables = Seq(
+      ExprCode("", "", "var1"),
+      ExprCode("", "", "_var2"),
+      ExprCode("", "", "$var3"),
+      ExprCode("", "", "v1a2r3"),
+      ExprCode("", "", "_1v2a3r"),
+      ExprCode("", "", "$1v2a3r"))
+
+    variables.foreach(v => assert(ExpressionCodegen.isLiteral(v) == false))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/c7d01486/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
index a9bfb63..05186c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
@@ -108,7 +108,10 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
          |}""".stripMargin)
 
     ctx.currentVars = null
+    // `rowIdx` isn't in `ctx.currentVars`, so a split function couldn't get it as a parameter.
+    // Declare it as a global (mutable state) variable instead.
     val rowidx = ctx.freshName("rowIdx")
+    ctx.addMutableState(ctx.JAVA_INT, rowidx)
     val columnsBatchInput = (output zip colVars).map { case (attr, colVar) =>
       genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable)
     }
@@ -128,7 +131,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
        |  int $numRows = $batch.numRows();
        |  int $localEnd = $numRows - $idx;
        |  for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
-       |    int $rowidx = $idx + $localIdx;
+       |    $rowidx = $idx + $localIdx;
        |    ${consume(ctx, columnsBatchInput).trim}
        |    $shouldStop
        |  }

http://git-wip-us.apache.org/repos/asf/spark/blob/c7d01486/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index bc05dca..1281169 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -17,7 +17,8 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.sql.{QueryTest, Row, SaveMode}
+import org.apache.spark.sql.{Column, QueryTest, Row, SaveMode}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
@@ -236,4 +237,24 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("SPARK-22551: Fix 64kb limit for deeply nested expressions under wholestage codegen") {
+    import testImplicits._
+    withTempPath { dir =>
+      val path = dir.getCanonicalPath
+      val df = Seq(("abc", 1)).toDF("key", "int")
+      df.write.parquet(path)
+
+      var strExpr: Expression = col("key").expr
+      for (_ <- 1 to 150) {
+        strExpr = Decode(Encode(strExpr, Literal("utf-8")), Literal("utf-8"))
+      }
+      val expressions = Seq(If(EqualTo(strExpr, strExpr), strExpr, strExpr))
+
+      val df2 = spark.read.parquet(path).select(expressions.map(Column(_)): _*)
+      val plan = df2.queryExecution.executedPlan
+      assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined)
+      checkAnswer(df2, Row("abc"))
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to