Repository: spark
Updated Branches:
  refs/heads/master a1877f45c -> 70221903f
[SPARK-22596][SQL] set ctx.currentVars in CodegenSupport.consume

## What changes were proposed in this pull request?

`ctx.currentVars` holds the input variables of the current operator. Since these are already determined in `CodegenSupport`, we can set `ctx.currentVars` there once, instead of in every `doConsume` implementation. This PR also adds more comments to help people understand the codegen framework.

After this PR, we have a principle for setting `ctx.currentVars` and `ctx.INPUT_ROW`:

1. For the non-whole-stage-codegen path, never set them (with a few special exceptions, such as generating ordering).
2. For the whole-stage-codegen `produce` path, we mostly don't need to set them, but blocking operators may need to set them for expressions that produce data from a data source, sort buffer, aggregate buffer, etc.
3. For the whole-stage-codegen `consume` path, we mostly don't need to set them, because `currentVars` is automatically set to the child's input variables and `INPUT_ROW` is mostly unused. A few plans need to tweak them, as they may have different inputs (e.g. a join build side or an aggregate buffer), or they may use the input row.
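To make principle 3 concrete, here is a minimal sketch (not part of this commit) of a consume-path operator under the new contract. The operator `MySelectionExec` and its boilerplate members are hypothetical; the point is the `doConsume` body, which can call `genCode` directly because `consume` has already pointed `ctx.currentVars` at the child's input variables:

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression}
    import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
    import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryExecNode}

    // Hypothetical filter-like operator, illustrating principle 3 only.
    case class MySelectionExec(condition: Expression, child: SparkPlan)
      extends UnaryExecNode with CodegenSupport {

      override def output: Seq[Attribute] = child.output

      override def inputRDDs(): Seq[RDD[InternalRow]] =
        child.asInstanceOf[CodegenSupport].inputRDDs()

      override protected def doProduce(ctx: CodegenContext): String =
        child.asInstanceOf[CodegenSupport].produce(ctx, this)

      override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
        // No `ctx.currentVars = input` here anymore: consume() did it before
        // calling us, so bound references resolve against `input` automatically.
        val ev = BindReferences.bindReference(condition, child.output).genCode(ctx)
        s"""
           |${ev.code}
           |if (!${ev.isNull} && ${ev.value}) {
           |  ${consume(ctx, input)}
           |}
         """.stripMargin
      }

      override protected def doExecute(): RDD[InternalRow] =
        throw new UnsupportedOperationException("sketch: codegen path only")
    }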
## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenc...@databricks.com>

Closes #19803 from cloud-fan/codegen.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/70221903
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/70221903
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/70221903

Branch: refs/heads/master
Commit: 70221903f54eaa0514d5d189dfb6f175a62228a8
Parents: a1877f4
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Fri Nov 24 21:50:30 2017 -0800
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Fri Nov 24 21:50:30 2017 -0800

----------------------------------------------------------------------
 .../catalyst/expressions/BoundAttribute.scala   | 23 +++++++++--------
 .../expressions/codegen/CodeGenerator.scala     | 14 +++++++---
 .../sql/execution/DataSourceScanExec.scala      | 14 +++++-----
 .../apache/spark/sql/execution/ExpandExec.scala |  3 ---
 .../spark/sql/execution/GenerateExec.scala      |  2 --
 .../sql/execution/WholeStageCodegenExec.scala   | 27 +++++++++++++++-----
 .../sql/execution/basicPhysicalOperators.scala  |  6 +----
 .../apache/spark/sql/execution/objects.scala    | 20 +++++----------
 8 files changed, 59 insertions(+), 50 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/70221903/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 7d16118..6a17a39 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -59,21 +59,24 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val javaType = ctx.javaType(dataType)
-    val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
     if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
       val oev = ctx.currentVars(ordinal)
       ev.isNull = oev.isNull
       ev.value = oev.value
-      val code = oev.code
-      oev.code = ""
-      ev.copy(code = code)
-    } else if (nullable) {
-      ev.copy(code = s"""
-        boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
-        $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);""")
+      ev.copy(code = oev.code)
     } else {
-      ev.copy(code = s"""$javaType ${ev.value} = $value;""", isNull = "false")
+      assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.")
+      val javaType = ctx.javaType(dataType)
+      val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
+      if (nullable) {
+        ev.copy(code =
+          s"""
+             |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
+             |$javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);
+           """.stripMargin)
+      } else {
+        ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = "false")
+      }
     }
   }
 }


http://git-wip-us.apache.org/repos/asf/spark/blob/70221903/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 9df8a8d..0498e61 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -134,6 +134,17 @@ class CodegenContext {
   }
 
   /**
+   * Holding the variable name of the input row of the current operator, will be used by
+   * `BoundReference` to generate code.
+   *
+   * Note that if `currentVars` is not null, `BoundReference` prefers `currentVars` over `INPUT_ROW`
+   * to generate code. If you want to make sure the generated code use `INPUT_ROW`, you need to set
+   * `currentVars` to null, or set `currentVars(i)` to null for certain columns, before calling
+   * `Expression.genCode`.
+   */
+  final var INPUT_ROW = "i"
+
+  /**
    * Holding a list of generated columns as input of current operator, will be used by
    * BoundReference to generate code.
    */
@@ -386,9 +397,6 @@ class CodegenContext {
   final val JAVA_FLOAT = "float"
   final val JAVA_DOUBLE = "double"
 
-  /** The variable name of the input row in generated code. */
-  final var INPUT_ROW = "i"
-
   /**
    * The map from a variable name to it's next ID.
    */
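The `INPUT_ROW` comment added above encodes a precedence rule that is easy to miss: `BoundReference` consults `currentVars` first and falls back to `INPUT_ROW` only when `currentVars` (or the relevant slot in it) is null. A minimal sketch of forcing the row-based branch; the exact generated variable names are illustrative:

    import org.apache.spark.sql.catalyst.expressions.BoundReference
    import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
    import org.apache.spark.sql.types.IntegerType

    val ctx = new CodegenContext

    // Point BoundReference at a row variable and clear the column variables,
    // otherwise currentVars would win and no row access would be generated.
    ctx.INPUT_ROW = "i"
    ctx.currentVars = null

    val ev = BoundReference(0, IntegerType, nullable = true).genCode(ctx)
    // ev.code now reads column 0 from the row, roughly:
    //   boolean isNull1 = i.isNullAt(0);
    //   int value1 = isNull1 ? -1 : (i.getInt(0));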
http://git-wip-us.apache.org/repos/asf/spark/blob/70221903/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index a477c23..747749b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -123,7 +123,7 @@ case class RowDataSourceScanExec(
       |while ($input.hasNext()) {
       |  InternalRow $row = (InternalRow) $input.next();
       |  $numOutputRows.add(1);
-      |  ${consume(ctx, columnsRowInput, null).trim}
+      |  ${consume(ctx, columnsRowInput).trim}
       |  if (shouldStop()) return;
       |}
     """.stripMargin
@@ -355,19 +355,21 @@ case class FileSourceScanExec(
     // PhysicalRDD always just has one input
     val input = ctx.freshName("input")
     ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
-    val exprRows = output.zipWithIndex.map{ case (a, i) =>
-      BoundReference(i, a.dataType, a.nullable)
-    }
     val row = ctx.freshName("row")
+    ctx.INPUT_ROW = row
     ctx.currentVars = null
-    val columnsRowInput = exprRows.map(_.genCode(ctx))
+    // Always provide `outputVars`, so that the framework can help us build unsafe row if the input
+    // row is not unsafe row, i.e. `needsUnsafeRowConversion` is true.
+    val outputVars = output.zipWithIndex.map{ case (a, i) =>
+      BoundReference(i, a.dataType, a.nullable).genCode(ctx)
+    }
     val inputRow = if (needsUnsafeRowConversion) null else row
     s"""
       |while ($input.hasNext()) {
       |  InternalRow $row = (InternalRow) $input.next();
       |  $numOutputRows.add(1);
-      |  ${consume(ctx, columnsRowInput, inputRow).trim}
+      |  ${consume(ctx, outputVars, inputRow).trim}
       |  if (shouldStop()) return;
       |}
     """.stripMargin
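The `FileSourceScanExec` hunk above is the produce-path pattern from principle 2 in miniature. Condensed below with extra commentary; `ctx`, `output`, `needsUnsafeRowConversion`, and `consume` are assumed in scope exactly as in the diff:

    // A source operator owns its input row, so it configures the context itself.
    val row = ctx.freshName("row")
    ctx.INPUT_ROW = row      // BoundReference will read each column from `row`
    ctx.currentVars = null   // no column variables exist yet at the source
    val outputVars = output.zipWithIndex.map { case (a, i) =>
      BoundReference(i, a.dataType, a.nullable).genCode(ctx)
    }
    // Hand the framework both forms: `outputVars` lets it build an UnsafeRow
    // when the input row is not unsafe; otherwise it can reuse `row` directly.
    val inputRow = if (needsUnsafeRowConversion) null else row
    consume(ctx, outputVars, inputRow)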
http://git-wip-us.apache.org/repos/asf/spark/blob/70221903/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
index 33849f4..a7bd5eb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
@@ -133,9 +133,6 @@ case class ExpandExec(
      * size explosion.
      */
 
-    // Set input variables
-    ctx.currentVars = input
-
     // Tracks whether a column has the same output for all rows.
     // Size of sameOutput array should equal N.
     // If sameOutput(i) is true, then the i-th column has the same value for all output rows given


http://git-wip-us.apache.org/repos/asf/spark/blob/70221903/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index c142d3b..e1562be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -135,8 +135,6 @@ case class GenerateExec(
   override def needCopyResult: Boolean = true
 
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    ctx.currentVars = input
-
     // Add input rows to the values when we are joining
     val values = if (join) {
       input
http://git-wip-us.apache.org/repos/asf/spark/blob/70221903/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 16b5706..7166b77 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -108,20 +108,22 @@ trait CodegenSupport extends SparkPlan {
 
   /**
    * Consume the generated columns or row from current SparkPlan, call its parent's `doConsume()`.
+   *
+   * Note that `outputVars` and `row` can't both be null.
    */
   final def consume(ctx: CodegenContext, outputVars: Seq[ExprCode], row: String = null): String = {
     val inputVars =
-      if (row != null) {
+      if (outputVars != null) {
+        assert(outputVars.length == output.length)
+        // outputVars will be used to generate the code for UnsafeRow, so we should copy them
+        outputVars.map(_.copy())
+      } else {
+        assert(row != null, "outputVars and row cannot both be null.")
         ctx.currentVars = null
         ctx.INPUT_ROW = row
         output.zipWithIndex.map { case (attr, i) =>
           BoundReference(i, attr.dataType, attr.nullable).genCode(ctx)
         }
-      } else {
-        assert(outputVars != null)
-        assert(outputVars.length == output.length)
-        // outputVars will be used to generate the code for UnsafeRow, so we should copy them
-        outputVars.map(_.copy())
       }
 
     val rowVar = if (row != null) {
@@ -147,6 +149,11 @@ trait CodegenSupport extends SparkPlan {
       }
     }
 
+    // Set up the `currentVars` in the codegen context, as we generate the code of `inputVars`
+    // before calling `parent.doConsume`. We can't set up `INPUT_ROW`, because parent needs to
+    // generate code of `rowVar` manually.
+    ctx.currentVars = inputVars
+    ctx.INPUT_ROW = null
     ctx.freshNamePrefix = parent.variablePrefix
     val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs)
     s"""
@@ -193,7 +200,8 @@ trait CodegenSupport extends SparkPlan {
   def usedInputs: AttributeSet = references
 
   /**
-   * Generate the Java source code to process the rows from child SparkPlan.
+   * Generate the Java source code to process the rows from child SparkPlan. This should only be
+   * called from `consume`.
    *
    * This should be override by subclass to support codegen.
    *
@@ -207,6 +215,11 @@ trait CodegenSupport extends SparkPlan {
   * }
   *
   * Note: A plan can either consume the rows as UnsafeRow (row), or a list of variables (input).
+   * When consuming as a listing of variables, the code to produce the input is already
+   * generated and `CodegenContext.currentVars` is already set. When consuming as UnsafeRow,
+   * implementations need to put `row.code` in the generated code and set
+   * `CodegenContext.INPUT_ROW` manually. Some plans may need more tweaks as they have
+   * different inputs(join build side, aggregate buffer, etc.), or other special cases.
   */
  def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
    throw new UnsupportedOperationException
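The expanded `doConsume` note above distinguishes the two consumption modes. The variable-based mode appears in the sketch after the PR description; below is a minimal sketch of the row-based mode. The `doConsume` body and the `condition` expression are hypothetical; the body follows the note's instructions to emit `row.code` and set `INPUT_ROW` manually:

    // Hypothetical doConsume consuming the input as an UnsafeRow.
    override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
      // `row.value` names the UnsafeRow variable built by the parent's consume();
      // point INPUT_ROW at it so later BoundReferences read from the row...
      ctx.INPUT_ROW = row.value
      // ...and clear currentVars, which would otherwise take precedence.
      ctx.currentVars = null
      val pred = BindReferences.bindReference(condition, child.output).genCode(ctx)
      s"""
         |${row.code}
         |${pred.code}
         |if (!${pred.isNull} && ${pred.value}) {
         |  ${consume(ctx, null, row.value)}
         |}
       """.stripMargin
    }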
http://git-wip-us.apache.org/repos/asf/spark/blob/70221903/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index f205bdf..c9a1514 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -56,9 +56,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
   }
 
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    val exprs = projectList.map(x =>
-      ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output)))
-    ctx.currentVars = input
+    val exprs = projectList.map(x => BindReferences.bindReference[Expression](x, child.output))
     val resultVars = exprs.map(_.genCode(ctx))
     // Evaluation of non-deterministic expressions can't be deferred.
     val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute)
@@ -152,8 +150,6 @@ case class FilterExec(condition: Expression, child: SparkPlan)
       """.stripMargin
     }
 
-    ctx.currentVars = input
-
     // To generate the predicates we will follow this algorithm.
     // For each predicate that is not IsNotNull, we will generate them one by one loading attributes
     // as necessary. For each of both attributes, if there is an IsNotNull predicate we will


http://git-wip-us.apache.org/repos/asf/spark/blob/70221903/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index d861109..d1bd8a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -81,11 +81,8 @@ case class DeserializeToObjectExec(
   }
 
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    val bound = ExpressionCanonicalizer.execute(
-      BindReferences.bindReference(deserializer, child.output))
-    ctx.currentVars = input
-    val resultVars = bound.genCode(ctx) :: Nil
-    consume(ctx, resultVars)
+    val resultObj = BindReferences.bindReference(deserializer, child.output).genCode(ctx)
+    consume(ctx, resultObj :: Nil)
   }
 
   override protected def doExecute(): RDD[InternalRow] = {
@@ -118,11 +115,9 @@ case class SerializeFromObjectExec(
   }
 
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    val bound = serializer.map { expr =>
-      ExpressionCanonicalizer.execute(BindReferences.bindReference(expr, child.output))
+    val resultVars = serializer.map { expr =>
+      BindReferences.bindReference[Expression](expr, child.output).genCode(ctx)
     }
-    ctx.currentVars = input
-    val resultVars = bound.map(_.genCode(ctx))
     consume(ctx, resultVars)
   }
 
@@ -224,12 +219,9 @@ case class MapElementsExec(
     val funcObj = Literal.create(func, ObjectType(funcClass))
     val callFunc = Invoke(funcObj, methodName, outputObjAttr.dataType, child.output)
 
-    val bound = ExpressionCanonicalizer.execute(
-      BindReferences.bindReference(callFunc, child.output))
-    ctx.currentVars = input
-    val resultVars = bound.genCode(ctx) :: Nil
+    val result = BindReferences.bindReference(callFunc, child.output).genCode(ctx)
 
-    consume(ctx, resultVars)
+    consume(ctx, result :: Nil)
   }
 
   override protected def doExecute(): RDD[InternalRow] = {