Github user kiszk commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19813#discussion_r153848873
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
 ---
    @@ -115,9 +118,240 @@ abstract class Expression extends 
TreeNode[Expression] {
         }
       }
     
    +  /**
    +   * Records current input row and variables for this expression into 
created `ExprCode`.
    +   */
    +  private def populateInputs(ctx: CodegenContext, eval: ExprCode): Unit = {
    +    if (ctx.INPUT_ROW != null) {
    +      eval.inputRow = ctx.INPUT_ROW
    +    }
    +    if (ctx.currentVars != null) {
    +      val boundRefs = this.collect {
    +        case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) 
!= null => (ordinal, b)
    +      }.toMap
    +
    +      ctx.currentVars.zipWithIndex.filter(_._1 != null).foreach { case 
(currentVar, idx) =>
    +        if (boundRefs.contains(idx)) {
    +          val inputVar = ExprInputVar(boundRefs(idx), exprCode = 
currentVar)
    +          eval.inputVars += inputVar
    +        }
    +      }
    +    }
    +  }
    +
    +  /**
    +   * Returns the eliminated subexpressions in the children expressions.
    +   */
    +  private def getSubExprInChildren(ctx: CodegenContext): Seq[Expression] = 
{
    +    children.flatMap { child =>
    +      child.collect {
    +        case e if ctx.subExprEliminationExprs.contains(e) => e
    +      }
    +    }
    +  }
    +
    +  /**
    +   * Given the list of eliminated subexpressions used in the children 
expressions, returns the
    +   * strings of function parameters. The first is the variable names used 
to call the function,
    +   * the second is the parameters used to declare the function in 
generated code.
    +   */
    +  private def getParamsForSubExprs(
    +      ctx: CodegenContext,
    +      subExprs: Seq[Expression]): (Seq[String], Seq[String]) = {
    +    subExprs.flatMap { subExpr =>
    +      val arguType = ctx.javaType(subExpr.dataType)
    +
    +      val subExprState = ctx.subExprEliminationExprs(subExpr)
    +      (subExprState.value, subExprState.isNull)
    +
    +      if (!subExpr.nullable || subExprState.isNull == "true" || 
subExprState.isNull == "false") {
    +        Seq((subExprState.value, s"$arguType ${subExprState.value}"))
    +      } else {
    +        Seq((subExprState.value, s"$arguType ${subExprState.value}"),
    +          (subExprState.isNull, s"boolean ${subExprState.isNull}"))
    +      }
    +    }.unzip
    +  }
    +
    +  /**
    +   * Retrieves previous input rows referred by children and deferred 
expressions.
    +   */
    +  private def getInputRowsForChildren(ctx: CodegenContext): Seq[String] = {
    +    children.flatMap(getInputRows(ctx, _)).distinct
    +  }
    +
    +  /**
    +   * Given a child expression, retrieves previous input rows referred by 
it or deferred expressions
    +   * which are needed to evaluate it.
    +   */
    +  private def getInputRows(ctx: CodegenContext, child: Expression): 
Seq[String] = {
    +    child.flatMap {
    +      // An expression directly evaluates on current input row.
    +      case BoundReference(ordinal, _, _) if ctx.currentVars == null ||
    +          ctx.currentVars(ordinal) == null =>
    +        Seq(ctx.INPUT_ROW)
    +
    +      // An expression which is not evaluated yet. Tracks down to find 
input rows.
    +      case BoundReference(ordinal, _, _) if ctx.currentVars(ordinal).code 
!= "" =>
    +        trackDownRow(ctx, ctx.currentVars(ordinal))
    +
    +      case _ => Seq.empty
    +    }.distinct
    +  }
    +
    +  /**
    +   * Tracks down input rows referred by the generated code snippet.
    +   */
    +  private def trackDownRow(ctx: CodegenContext, exprCode: ExprCode): 
Seq[String] = {
    +    var exprCodes: List[ExprCode] = List(exprCode)
    +    val inputRows = mutable.ArrayBuffer.empty[String]
    +
    +    while (exprCodes.nonEmpty) {
    +      exprCodes match {
    +        case first :: others =>
    +          exprCodes = others
    +          if (first.inputRow != null) {
    +            inputRows += first.inputRow
    +          }
    +          first.inputVars.foreach { inputVar =>
    +            if (inputVar.exprCode.code != "") {
    +              exprCodes = inputVar.exprCode :: exprCodes
    +            }
    +          }
    +        case _ =>
    +      }
    +    }
    +    inputRows.toSeq
    +  }
    +
    +  /**
    +   * Retrieves previously evaluated columns referred by children and 
deferred expressions.
    +   * Returned tuple contains the list of expressions and the list of 
generated codes.
    +   */
    +  private def getInputVarsForChildren(ctx: CodegenContext): 
(Seq[Expression], Seq[ExprCode]) = {
    +    children.flatMap(getInputVars(ctx, _)).distinct.unzip
    +  }
    +
    +  /**
    +   * Given a child expression, retrieves previously evaluated columns 
referred by it or
    +   * deferred expressions which are needed to evaluate it.
    +   */
    +  private def getInputVars(ctx: CodegenContext, child: Expression): 
Seq[(Expression, ExprCode)] = {
    +    if (ctx.currentVars == null) {
    +      return Seq.empty
    +    }
    +
    +    child.flatMap {
    +      // An evaluated variable.
    +      case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) 
!= null &&
    +          ctx.currentVars(ordinal).code == "" =>
    +        Seq((b, ctx.currentVars(ordinal)))
    +
    +      // An input variable which is not evaluated yet. Tracks down to find 
any evaluated variables
    +      // in the expression path.
    +      // E.g., if this expression is "d = c + 1" and "c" is not evaluated. 
We need to track to
    +      // "c = a + b" and see if "a" and "b" are evaluated. If they are, we 
need to return them so
    +      // to include them into parameters, if not, we track down further.
    +      case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) 
!= null =>
    +        trackDownVar(ctx, ctx.currentVars(ordinal))
    +
    +      case _ => Seq.empty
    +    }.distinct
    +  }
    +
    +  /**
    +   * Tracks down previously evaluated columns referred by the generated 
code snippet.
    +   */
    +  private def trackDownVar(ctx: CodegenContext, exprCode: ExprCode): 
Seq[(Expression, ExprCode)] = {
    +    var exprCodes: List[ExprCode] = List(exprCode)
    +    val inputVars = mutable.ArrayBuffer.empty[(Expression, ExprCode)]
    +
    +    while (exprCodes.nonEmpty) {
    +      exprCodes match {
    +        case first :: others =>
    +          exprCodes = others
    +          first.inputVars.foreach { inputVar =>
    +            if (inputVar.exprCode.code == "") {
    +              inputVars += ((inputVar.expr, inputVar.exprCode))
    +            } else {
    +              exprCodes = inputVar.exprCode :: exprCodes
    +            }
    +          }
    +        case _ =>
    +      }
    +    }
    +    inputVars.toSeq
    +  }
    +
    +  /**
    +   * Helper function to calculate the size of an expression as function 
parameter.
    +   */
    +  private def calculateParamLength(ctx: CodegenContext, input: 
Expression): Int = {
    +    ctx.javaType(input.dataType) match {
    +      case (ctx.JAVA_LONG | ctx.JAVA_DOUBLE) if !input.nullable => 2
    +      case ctx.JAVA_LONG | ctx.JAVA_DOUBLE => 3
    +      case _ if !input.nullable => 1
    +      case _ => 2
    +    }
    +  }
    +
    +  /**
    +   * In Java, a method descriptor is valid only if it represents method 
parameters with a total
    +   * length of 255 or less. `this` contributes one unit and a parameter of 
type long or double
    +   * contributes two units.
    +   */
    +  private def getValidParamLength(
    +      ctx: CodegenContext,
    +      inputs: Seq[Expression],
    +      subExprs: Seq[Expression]): Int = {
    +    // Start value is 1 for `this`.
    +    inputs.foldLeft(1) { case (curLength, input) =>
    +      curLength + calculateParamLength(ctx, input)
    +    } + subExprs.foldLeft(0) { case (curLength, subExpr) =>
    +      curLength + calculateParamLength(ctx, subExpr)
    +    }
    +  }
    +
    +  /**
    +   * Given the lists of input attributes and variables to this expression, 
returns the strings of
function parameters. The first is the variable names used to call the 
function, the second is
    +   * the parameters used to declare the function in generated code.
    +   */
    +  private def prepareFunctionParams(
    +      ctx: CodegenContext,
    +      inputAttrs: Seq[Expression],
    +      inputVars: Seq[ExprCode]): (Seq[String], Seq[String]) = {
    +    inputAttrs.zip(inputVars).flatMap { case (input, ev) =>
    +      val arguType = ctx.javaType(input.dataType)
    --- End diff --
    
    ditto


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to