Github user kiszk commented on a diff in the pull request:
https://github.com/apache/spark/pull/19813#discussion_r153848873
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
---
@@ -115,9 +118,240 @@ abstract class Expression extends
TreeNode[Expression] {
}
}
+ /**
+ * Records current input row and variables for this expression into
created `ExprCode`.
+ */
+ private def populateInputs(ctx: CodegenContext, eval: ExprCode): Unit = {
+ if (ctx.INPUT_ROW != null) {
+ eval.inputRow = ctx.INPUT_ROW
+ }
+ if (ctx.currentVars != null) {
+ val boundRefs = this.collect {
+ case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal)
!= null => (ordinal, b)
+ }.toMap
+
+ ctx.currentVars.zipWithIndex.filter(_._1 != null).foreach { case
(currentVar, idx) =>
+ if (boundRefs.contains(idx)) {
+ val inputVar = ExprInputVar(boundRefs(idx), exprCode =
currentVar)
+ eval.inputVars += inputVar
+ }
+ }
+ }
+ }
+
+ /**
+ * Returns the eliminated subexpressions in the children expressions.
+ */
+ private def getSubExprInChildren(ctx: CodegenContext): Seq[Expression] =
{
+ children.flatMap { child =>
+ child.collect {
+ case e if ctx.subExprEliminationExprs.contains(e) => e
+ }
+ }
+ }
+
+ /**
+ * Given the list of eliminated subexpressions used in the children
expressions, returns the
+ * strings of function parameters. The first is the variable names used
to call the function,
+ * the second is the parameters used to declare the function in
generated code.
+ */
+ private def getParamsForSubExprs(
+ ctx: CodegenContext,
+ subExprs: Seq[Expression]): (Seq[String], Seq[String]) = {
+ subExprs.flatMap { subExpr =>
+ val arguType = ctx.javaType(subExpr.dataType)
+
+ val subExprState = ctx.subExprEliminationExprs(subExpr)
+ (subExprState.value, subExprState.isNull)
+
+ if (!subExpr.nullable || subExprState.isNull == "true" ||
subExprState.isNull == "false") {
+ Seq((subExprState.value, s"$arguType ${subExprState.value}"))
+ } else {
+ Seq((subExprState.value, s"$arguType ${subExprState.value}"),
+ (subExprState.isNull, s"boolean ${subExprState.isNull}"))
+ }
+ }.unzip
+ }
+
+ /**
+ * Retrieves previous input rows referred by children and deferred
expressions.
+ */
+ private def getInputRowsForChildren(ctx: CodegenContext): Seq[String] = {
+ children.flatMap(getInputRows(ctx, _)).distinct
+ }
+
+ /**
+ * Given a child expression, retrieves previous input rows referred by
it or deferred expressions
+ * which are needed to evaluate it.
+ */
+ private def getInputRows(ctx: CodegenContext, child: Expression):
Seq[String] = {
+ child.flatMap {
+ // An expression directly evaluates on current input row.
+ case BoundReference(ordinal, _, _) if ctx.currentVars == null ||
+ ctx.currentVars(ordinal) == null =>
+ Seq(ctx.INPUT_ROW)
+
+ // An expression which is not evaluated yet. Tracks down to find
input rows.
+ case BoundReference(ordinal, _, _) if ctx.currentVars(ordinal).code
!= "" =>
+ trackDownRow(ctx, ctx.currentVars(ordinal))
+
+ case _ => Seq.empty
+ }.distinct
+ }
+
+ /**
+ * Tracks down input rows referred by the generated code snippet.
+ */
+ private def trackDownRow(ctx: CodegenContext, exprCode: ExprCode):
Seq[String] = {
+ var exprCodes: List[ExprCode] = List(exprCode)
+ val inputRows = mutable.ArrayBuffer.empty[String]
+
+ while (exprCodes.nonEmpty) {
+ exprCodes match {
+ case first :: others =>
+ exprCodes = others
+ if (first.inputRow != null) {
+ inputRows += first.inputRow
+ }
+ first.inputVars.foreach { inputVar =>
+ if (inputVar.exprCode.code != "") {
+ exprCodes = inputVar.exprCode :: exprCodes
+ }
+ }
+ case _ =>
+ }
+ }
+ inputRows.toSeq
+ }
+
+ /**
+ * Retrieves previously evaluated columns referred by children and
deferred expressions.
+ * Returned tuple contains the list of expressions and the list of
generated codes.
+ */
+ private def getInputVarsForChildren(ctx: CodegenContext):
(Seq[Expression], Seq[ExprCode]) = {
+ children.flatMap(getInputVars(ctx, _)).distinct.unzip
+ }
+
+ /**
+ * Given a child expression, retrieves previously evaluated columns
referred by it or
+ * deferred expressions which are needed to evaluate it.
+ */
+ private def getInputVars(ctx: CodegenContext, child: Expression):
Seq[(Expression, ExprCode)] = {
+ if (ctx.currentVars == null) {
+ return Seq.empty
+ }
+
+ child.flatMap {
+ // An evaluated variable.
+ case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal)
!= null &&
+ ctx.currentVars(ordinal).code == "" =>
+ Seq((b, ctx.currentVars(ordinal)))
+
+ // An input variable which is not evaluated yet. Tracks down to find
any evaluated variables
+ // in the expression path.
+ // E.g., if this expression is "d = c + 1" and "c" is not evaluated.
We need to track to
+ // "c = a + b" and see if "a" and "b" are evaluated. If they are, we
need to return them so
+ // to include them into parameters, if not, we track down further.
+ case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal)
!= null =>
+ trackDownVar(ctx, ctx.currentVars(ordinal))
+
+ case _ => Seq.empty
+ }.distinct
+ }
+
+ /**
+ * Tracks down previously evaluated columns referred by the generated
code snippet.
+ */
+ private def trackDownVar(ctx: CodegenContext, exprCode: ExprCode):
Seq[(Expression, ExprCode)] = {
+ var exprCodes: List[ExprCode] = List(exprCode)
+ val inputVars = mutable.ArrayBuffer.empty[(Expression, ExprCode)]
+
+ while (exprCodes.nonEmpty) {
+ exprCodes match {
+ case first :: others =>
+ exprCodes = others
+ first.inputVars.foreach { inputVar =>
+ if (inputVar.exprCode.code == "") {
+ inputVars += ((inputVar.expr, inputVar.exprCode))
+ } else {
+ exprCodes = inputVar.exprCode :: exprCodes
+ }
+ }
+ case _ =>
+ }
+ }
+ inputVars.toSeq
+ }
+
+ /**
+ * Helper function to calculate the size of an expression as function
parameter.
+ */
+ private def calculateParamLength(ctx: CodegenContext, input:
Expression): Int = {
+ ctx.javaType(input.dataType) match {
+ case (ctx.JAVA_LONG | ctx.JAVA_DOUBLE) if !input.nullable => 2
+ case ctx.JAVA_LONG | ctx.JAVA_DOUBLE => 3
+ case _ if !input.nullable => 1
+ case _ => 2
+ }
+ }
+
+ /**
+ * In Java, a method descriptor is valid only if it represents method
parameters with a total
+ * length of 255 or less. `this` contributes one unit and a parameter of
type long or double
+ * contributes two units.
+ */
+ private def getValidParamLength(
+ ctx: CodegenContext,
+ inputs: Seq[Expression],
+ subExprs: Seq[Expression]): Int = {
+ // Start value is 1 for `this`.
+ inputs.foldLeft(1) { case (curLength, input) =>
+ curLength + calculateParamLength(ctx, input)
+ } + subExprs.foldLeft(0) { case (curLength, subExpr) =>
+ curLength + calculateParamLength(ctx, subExpr)
+ }
+ }
+
+ /**
+ * Given the lists of input attributes and variables to this expression,
returns the strings of
+ * function parameters. The first is the variable names used to call the
function, the second is
+ * the parameters used to declare the function in generated code.
+ */
+ private def prepareFunctionParams(
+ ctx: CodegenContext,
+ inputAttrs: Seq[Expression],
+ inputVars: Seq[ExprCode]): (Seq[String], Seq[String]) = {
+ inputAttrs.zip(inputVars).flatMap { case (input, ev) =>
+ val arguType = ctx.javaType(input.dataType)
--- End diff --
ditto
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]