Github user kiszk commented on a diff in the pull request:
https://github.com/apache/spark/pull/19813#discussion_r153852592
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
---
@@ -115,9 +118,240 @@ abstract class Expression extends TreeNode[Expression] {
}
}
+  /**
+   * Records on `eval` the codegen context's current input row (when one is set) and, for every
+   * `BoundReference` in this expression whose ordinal has a non-null entry in
+   * `ctx.currentVars`, an `ExprInputVar` pairing that reference with its variable.
+   */
+  private def populateInputs(ctx: CodegenContext, eval: ExprCode): Unit = {
+    Option(ctx.INPUT_ROW).foreach(row => eval.inputRow = row)
+
+    val vars = ctx.currentVars
+    if (vars != null) {
+      // Ordinal -> reference, restricted to ordinals that have a variable in `currentVars`.
+      val refsByOrdinal = this.collect {
+        case ref: BoundReference if vars(ref.ordinal) != null => ref.ordinal -> ref
+      }.toMap
+
+      // Walk `currentVars` in ordinal order so `inputVars` is filled in ascending ordinal order.
+      for {
+        (exprCode, ordinal) <- vars.zipWithIndex
+        if exprCode != null
+        ref <- refsByOrdinal.get(ordinal)
+      } eval.inputVars += ExprInputVar(ref, exprCode = exprCode)
+    }
+  }
+
+  /**
+   * Collects, from every child subtree (children in order, each traversed with `collect`),
+   * the expressions that have an entry in `ctx.subExprEliminationExprs`, i.e. the eliminated
+   * subexpressions the children use. Duplicates are kept.
+   */
+  private def getSubExprInChildren(ctx: CodegenContext): Seq[Expression] = {
+    val eliminated = ctx.subExprEliminationExprs
+    for {
+      child <- children
+      subExpr <- child.collect { case e if eliminated.contains(e) => e }
+    } yield subExpr
+  }
+
+  /**
+   * Given the list of eliminated subexpressions used in the children expressions, returns the
+   * strings of function parameters. The first sequence holds the variable names used to call
+   * the function; the second holds the matching parameter declarations used to declare the
+   * function in generated code.
+   *
+   * Every subexpression contributes its value variable. Its null flag is added as an extra
+   * `boolean` parameter only when the expression is nullable and the flag is a real variable
+   * rather than the literal `true` / `false`.
+   */
+  private def getParamsForSubExprs(
+      ctx: CodegenContext,
+      subExprs: Seq[Expression]): (Seq[String], Seq[String]) = {
+    subExprs.flatMap { subExpr =>
+      val argType = ctx.javaType(subExpr.dataType)
+      val state = ctx.subExprEliminationExprs(subExpr)
+      val valueParam = (state.value, s"$argType ${state.value}")
+      // A literal "true"/"false" is not a variable name, so it cannot be declared as a
+      // parameter; a non-nullable expression needs no null flag at all.
+      val isNullIsLiteral = state.isNull == "true" || state.isNull == "false"
+      if (!subExpr.nullable || isNullIsLiteral) {
+        Seq(valueParam)
+      } else {
+        Seq(valueParam, (state.isNull, s"boolean ${state.isNull}"))
+      }
+    }.unzip
+  }
+
+  /**
+   * Returns the distinct input rows read by the children, including rows reached through
+   * deferred (not yet evaluated) expressions they depend on.
+   */
+  private def getInputRowsForChildren(ctx: CodegenContext): Seq[String] = {
+    val rows = for {
+      child <- children
+      row <- getInputRows(ctx, child)
+    } yield row
+    rows.distinct
+  }
+
+  /**
+   * Given a child expression, returns the distinct input rows needed to evaluate it: the
+   * current input row for references with no variable in `ctx.currentVars`, plus the rows
+   * reached through variables that are not evaluated yet (non-empty `code`).
+   */
+  private def getInputRows(ctx: CodegenContext, child: Expression): Seq[String] = {
+    val rows = child.flatMap {
+      case BoundReference(ordinal, _, _) =>
+        val vars = ctx.currentVars
+        if (vars == null || vars(ordinal) == null) {
+          // No variable for this column: it is read directly from the current input row.
+          Seq(ctx.INPUT_ROW)
+        } else if (vars(ordinal).code != "") {
+          // The variable's code has not run yet: follow its inputs back to their rows.
+          trackDownRow(ctx, vars(ordinal))
+        } else {
+          Seq.empty
+        }
+      case _ => Seq.empty
+    }
+    rows.distinct
+  }
+
+  /**
+   * Tracks down input rows referred by a generated code snippet. Starting from `exprCode`,
+   * it records each visited snippet's `inputRow` (when set). It then visits, stack-wise,
+   * every input variable whose `code` is still non-empty. Rows are returned in visit order
+   * and may contain duplicates.
+   */
+  private def trackDownRow(ctx: CodegenContext, exprCode: ExprCode): Seq[String] = {
+    @scala.annotation.tailrec
+    def visit(pending: List[ExprCode], found: List[String]): List[String] = pending match {
+      case Nil => found.reverse
+      case current :: rest =>
+        val withRow = if (current.inputRow != null) current.inputRow :: found else found
+        // Each qualifying input var is pushed in turn, so the last one is visited next.
+        val stack = current.inputVars.foldLeft(rest) { (acc, inputVar) =>
+          if (inputVar.exprCode.code != "") inputVar.exprCode :: acc else acc
+        }
+        visit(stack, withRow)
+    }
+
+    visit(List(exprCode), Nil)
+  }
+
+  /**
+   * Returns the previously evaluated columns referred by the children and the deferred
+   * expressions they depend on, deduplicated. The tuple is split into the referencing
+   * expressions and their generated code.
+   */
+  private def getInputVarsForChildren(
+      ctx: CodegenContext): (Seq[Expression], Seq[ExprCode]) = {
+    val pairs = children.flatMap(child => getInputVars(ctx, child))
+    pairs.distinct.unzip
+  }
+
+ /**
+ * Given a child expression, retrieves previously evaluated columns
referred by it or
+ * deferred expressions which are needed to evaluate it.
+ */
+ private def getInputVars(ctx: CodegenContext, child: Expression):
Seq[(Expression, ExprCode)] = {
+ if (ctx.currentVars == null) {
+ return Seq.empty
+ }
+
+ child.flatMap {
+ // An evaluated variable.
+ case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal)
!= null &&
+ ctx.currentVars(ordinal).code == "" =>
+ Seq((b, ctx.currentVars(ordinal)))
+
+ // An input variable which is not evaluated yet. Tracks down to find
any evaluated variables
+ // in the expression path.
+ // E.g., if this expression is "d = c + 1" and "c" is not evaluated.
We need to track to
+ // "c = a + b" and see if "a" and "b" are evaluated. If they are, we
need to return them so
+ // to include them into parameters, if not, we tract down further.
--- End diff --
nit: `tract` -> `track`?
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]