Github user kiszk commented on a diff in the pull request: https://github.com/apache/spark/pull/19813#discussion_r153848873 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala --- @@ -115,9 +118,240 @@ abstract class Expression extends TreeNode[Expression] { } } + /** + * Records current input row and variables for this expression into created `ExprCode`. + */ + private def populateInputs(ctx: CodegenContext, eval: ExprCode): Unit = { + if (ctx.INPUT_ROW != null) { + eval.inputRow = ctx.INPUT_ROW + } + if (ctx.currentVars != null) { + val boundRefs = this.collect { + case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null => (ordinal, b) + }.toMap + + ctx.currentVars.zipWithIndex.filter(_._1 != null).foreach { case (currentVar, idx) => + if (boundRefs.contains(idx)) { + val inputVar = ExprInputVar(boundRefs(idx), exprCode = currentVar) + eval.inputVars += inputVar + } + } + } + } + + /** + * Returns the eliminated subexpressions in the children expressions. + */ + private def getSubExprInChildren(ctx: CodegenContext): Seq[Expression] = { + children.flatMap { child => + child.collect { + case e if ctx.subExprEliminationExprs.contains(e) => e + } + } + } + + /** + * Given the list of eliminated subexpressions used in the children expressions, returns the + * strings of function parameters. The first is the variable names used to call the function, + * the second is the parameters used to declare the function in generated code. 
+ */ + private def getParamsForSubExprs( + ctx: CodegenContext, + subExprs: Seq[Expression]): (Seq[String], Seq[String]) = { + subExprs.flatMap { subExpr => + val arguType = ctx.javaType(subExpr.dataType) + + val subExprState = ctx.subExprEliminationExprs(subExpr) + (subExprState.value, subExprState.isNull) + + if (!subExpr.nullable || subExprState.isNull == "true" || subExprState.isNull == "false") { + Seq((subExprState.value, s"$arguType ${subExprState.value}")) + } else { + Seq((subExprState.value, s"$arguType ${subExprState.value}"), + (subExprState.isNull, s"boolean ${subExprState.isNull}")) + } + }.unzip + } + + /** + * Retrieves previous input rows referred by children and deferred expressions. + */ + private def getInputRowsForChildren(ctx: CodegenContext): Seq[String] = { + children.flatMap(getInputRows(ctx, _)).distinct + } + + /** + * Given a child expression, retrieves previous input rows referred by it or deferred expressions + * which are needed to evaluate it. + */ + private def getInputRows(ctx: CodegenContext, child: Expression): Seq[String] = { + child.flatMap { + // An expression directly evaluates on current input row. + case BoundReference(ordinal, _, _) if ctx.currentVars == null || + ctx.currentVars(ordinal) == null => + Seq(ctx.INPUT_ROW) + + // An expression which is not evaluated yet. Tracks down to find input rows. + case BoundReference(ordinal, _, _) if ctx.currentVars(ordinal).code != "" => + trackDownRow(ctx, ctx.currentVars(ordinal)) + + case _ => Seq.empty + }.distinct + } + + /** + * Tracks down input rows referred by the generated code snippet. 
+ */ + private def trackDownRow(ctx: CodegenContext, exprCode: ExprCode): Seq[String] = { + var exprCodes: List[ExprCode] = List(exprCode) + val inputRows = mutable.ArrayBuffer.empty[String] + + while (exprCodes.nonEmpty) { + exprCodes match { + case first :: others => + exprCodes = others + if (first.inputRow != null) { + inputRows += first.inputRow + } + first.inputVars.foreach { inputVar => + if (inputVar.exprCode.code != "") { + exprCodes = inputVar.exprCode :: exprCodes + } + } + case _ => + } + } + inputRows.toSeq + } + + /** + * Retrieves previously evaluated columns referred by children and deferred expressions. + * Returned tuple contains the list of expressions and the list of generated codes. + */ + private def getInputVarsForChildren(ctx: CodegenContext): (Seq[Expression], Seq[ExprCode]) = { + children.flatMap(getInputVars(ctx, _)).distinct.unzip + } + + /** + * Given a child expression, retrieves previously evaluated columns referred by it or + * deferred expressions which are needed to evaluate it. + */ + private def getInputVars(ctx: CodegenContext, child: Expression): Seq[(Expression, ExprCode)] = { + if (ctx.currentVars == null) { + return Seq.empty + } + + child.flatMap { + // An evaluated variable. + case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null && + ctx.currentVars(ordinal).code == "" => + Seq((b, ctx.currentVars(ordinal))) + + // An input variable which is not evaluated yet. Tracks down to find any evaluated variables + // in the expression path. + // E.g., if this expression is "d = c + 1" and "c" is not evaluated. We need to track to + // "c = a + b" and see if "a" and "b" are evaluated. If they are, we need to return them so + // to include them into parameters, if not, we track down further. 
+ case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null => + trackDownVar(ctx, ctx.currentVars(ordinal)) + + case _ => Seq.empty + }.distinct + } + + /** + * Tracks down previously evaluated columns referred by the generated code snippet. + */ + private def trackDownVar(ctx: CodegenContext, exprCode: ExprCode): Seq[(Expression, ExprCode)] = { + var exprCodes: List[ExprCode] = List(exprCode) + val inputVars = mutable.ArrayBuffer.empty[(Expression, ExprCode)] + + while (exprCodes.nonEmpty) { + exprCodes match { + case first :: others => + exprCodes = others + first.inputVars.foreach { inputVar => + if (inputVar.exprCode.code == "") { + inputVars += ((inputVar.expr, inputVar.exprCode)) + } else { + exprCodes = inputVar.exprCode :: exprCodes + } + } + case _ => + } + } + inputVars.toSeq + } + + /** + * Helper function to calculate the size of an expression as function parameter. + */ + private def calculateParamLength(ctx: CodegenContext, input: Expression): Int = { + ctx.javaType(input.dataType) match { + case (ctx.JAVA_LONG | ctx.JAVA_DOUBLE) if !input.nullable => 2 + case ctx.JAVA_LONG | ctx.JAVA_DOUBLE => 3 + case _ if !input.nullable => 1 + case _ => 2 + } + } + + /** + * In Java, a method descriptor is valid only if it represents method parameters with a total + * length of 255 or less. `this` contributes one unit and a parameter of type long or double + * contributes two units. + */ + private def getValidParamLength( + ctx: CodegenContext, + inputs: Seq[Expression], + subExprs: Seq[Expression]): Int = { + // Start value is 1 for `this`. + inputs.foldLeft(1) { case (curLength, input) => + curLength + calculateParamLength(ctx, input) + } + subExprs.foldLeft(0) { case (curLength, subExpr) => + curLength + calculateParamLength(ctx, subExpr) + } + } + + /** + * Given the lists of input attributes and variables to this expression, returns the strings of + * function parameters. 
The first is the variable names used to call the function, the second is + * the parameters used to declare the function in generated code. + */ + private def prepareFunctionParams( + ctx: CodegenContext, + inputAttrs: Seq[Expression], + inputVars: Seq[ExprCode]): (Seq[String], Seq[String]) = { + inputAttrs.zip(inputVars).flatMap { case (input, ev) => + val arguType = ctx.javaType(input.dataType) --- End diff -- ditto
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org