Repository: spark Updated Branches: refs/heads/master a91784fb6 -> bd94ea4c8
[SPARK-14175][SQL] whole stage codegen interface refactor ## What changes were proposed in this pull request? 1. merge consumeChild into consume() 2. always generate code for input variables and UnsafeRow, a plan can use either of them. ## How was this patch tested? Existing tests. Author: Davies Liu <[email protected]> Closes #11975 from davies/gen_refactor. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bd94ea4c Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bd94ea4c Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bd94ea4c Branch: refs/heads/master Commit: bd94ea4c80f4fc18f4000346d7c6717539846efb Parents: a91784f Author: Davies Liu <[email protected]> Authored: Sat Mar 26 11:03:05 2016 -0700 Committer: Reynold Xin <[email protected]> Committed: Sat Mar 26 11:03:05 2016 -0700 ---------------------------------------------------------------------- .../spark/sql/execution/ExistingRDD.scala | 3 +- .../org/apache/spark/sql/execution/Expand.scala | 2 +- .../org/apache/spark/sql/execution/Sort.scala | 26 +--- .../spark/sql/execution/WholeStageCodegen.scala | 153 +++++++------------ .../execution/aggregate/TungstenAggregate.scala | 2 +- .../spark/sql/execution/basicOperators.scala | 4 +- .../spark/sql/execution/debug/package.scala | 2 +- .../sql/execution/joins/BroadcastHashJoin.scala | 2 +- .../org/apache/spark/sql/execution/limit.scala | 2 +- 9 files changed, 72 insertions(+), 124 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/bd94ea4c/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 3e2c799..815ff01 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -271,7 +271,8 @@ private[sql] case class DataSourceScan( | } | }""".stripMargin) - val exprRows = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true)) + val exprRows = + output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, x._1.nullable)) ctx.INPUT_ROW = row ctx.currentVars = null val columns2 = exprRows.map(_.gen(ctx)) http://git-wip-us.apache.org/repos/asf/spark/blob/bd94ea4c/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala index 05627ba..bd23b7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala @@ -93,7 +93,7 @@ case class Expand( child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { /* * When the projections list looks like: * expr1A, exprB, expr1C http://git-wip-us.apache.org/repos/asf/spark/blob/bd94ea4c/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala index b4dd770..efd8760 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala @@ -98,6 +98,8 @@ case class Sort( } } + override def usedInputs: AttributeSet = 
AttributeSet(Seq.empty) + override def upstreams(): Seq[RDD[InternalRow]] = { child.asInstanceOf[CodegenSupport].upstreams() } @@ -105,8 +107,6 @@ case class Sort( // Name of sorter variable used in codegen. private var sorterVariable: String = _ - override def preferUnsafeRow: Boolean = true - override protected def doProduce(ctx: CodegenContext): String = { val needToSort = ctx.freshName("needToSort") ctx.addMutableState("boolean", needToSort, s"$needToSort = true;") @@ -158,22 +158,10 @@ case class Sort( """.stripMargin.trim } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = { - if (row != null) { - s"$sorterVariable.insertRow((UnsafeRow)$row);" - } else { - val colExprs = child.output.zipWithIndex.map { case (attr, i) => - BoundReference(i, attr.dataType, attr.nullable) - } - - ctx.currentVars = input - val code = GenerateUnsafeProjection.createCode(ctx, colExprs) - - s""" - | // Convert the input attributes to an UnsafeRow and add it to the sorter - | ${code.code} - | $sorterVariable.insertRow(${code.value}); - """.stripMargin.trim - } + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + s""" + |${row.code} + |$sorterVariable.insertRow((UnsafeRow)${row.value}); + """.stripMargin } } http://git-wip-us.apache.org/repos/asf/spark/blob/bd94ea4c/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 0be0b80..1b13c8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -69,11 +69,6 @@ trait CodegenSupport extends SparkPlan { protected var parent: CodegenSupport = null /** - * Whether this 
SparkPlan prefers to accept UnsafeRow as input in doConsume. - */ - def preferUnsafeRow: Boolean = false - - /** * Returns all the RDDs of InternalRow which generates the input rows. * * Note: right now we support up to two RDDs. @@ -114,13 +109,52 @@ trait CodegenSupport extends SparkPlan { protected def doProduce(ctx: CodegenContext): String /** - * Consume the columns generated from current SparkPlan, call it's parent. + * Consume the generated columns or row from current SparkPlan, call it's parent's doConsume(). */ - final def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { - if (input != null) { - assert(input.length == output.length) + final def consume(ctx: CodegenContext, outputVars: Seq[ExprCode], row: String = null): String = { + val inputVars = + if (row != null) { + ctx.currentVars = null + ctx.INPUT_ROW = row + output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable).gen(ctx) + } + } else { + assert(outputVars != null) + assert(outputVars.length == output.length) + // outputVars will be used to generate the code for UnsafeRow, so we should copy them + outputVars.map(_.copy()) + } + val rowVar = if (row != null) { + ExprCode("", "false", row) + } else { + if (outputVars.nonEmpty) { + val colExprs = output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable) + } + val evaluateInputs = evaluateVariables(outputVars) + // generate the code to create a UnsafeRow + ctx.currentVars = outputVars + val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false) + val code = s""" + |$evaluateInputs + |${ev.code.trim} + """.stripMargin.trim + ExprCode(code, "false", ev.value) + } else { + // There is no columns + ExprCode("", "false", "unsafeRow") + } } - parent.consumeChild(ctx, this, input, row) + + ctx.freshNamePrefix = parent.variablePrefix + val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs) + s""" + | + |/*** CONSUME: 
${toCommentSafeString(parent.simpleString)} */ + |${evaluated} + |${parent.doConsume(ctx, inputVars, rowVar)} + """.stripMargin } /** @@ -160,47 +194,6 @@ trait CodegenSupport extends SparkPlan { def usedInputs: AttributeSet = references /** - * Consume the columns generated from its child, call doConsume() or emit the rows. - * - * An operator could generate variables for the output, or a row, either one could be null. - * - * If the row is not null, we create variables to access the columns that are actually used by - * current plan before calling doConsume(). - */ - def consumeChild( - ctx: CodegenContext, - child: SparkPlan, - input: Seq[ExprCode], - row: String = null): String = { - ctx.freshNamePrefix = variablePrefix - val inputVars = - if (row != null) { - ctx.currentVars = null - ctx.INPUT_ROW = row - child.output.zipWithIndex.map { case (attr, i) => - BoundReference(i, attr.dataType, attr.nullable).gen(ctx) - } - } else { - input - } - - val evaluated = - if (row != null && preferUnsafeRow) { - // Current plan can consume UnsafeRows directly. - "" - } else { - evaluateRequiredVariables(child.output, inputVars, usedInputs) - } - - s""" - | - |/*** CONSUME: ${toCommentSafeString(this.simpleString)} */ - |${evaluated} - |${doConsume(ctx, inputVars, row)} - """.stripMargin - } - - /** * Generate the Java source code to process the rows from child SparkPlan. * * This should be override by subclass to support codegen. @@ -210,8 +203,10 @@ trait CodegenSupport extends SparkPlan { * # code to evaluate the predicate expression, result is isNull1 and value2 * if (isNull1 || !value2) continue; * # call consume(), which will call parent.doConsume() + * + * Note: A plan can either consume the rows as UnsafeRow (row), or a list of variables (input). 
*/ - protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = { + def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { throw new UnsupportedOperationException } } @@ -245,16 +240,11 @@ case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport val input = ctx.freshName("input") // Right now, InputAdapter is only used when there is one upstream. ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") - - val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true)) val row = ctx.freshName("row") - ctx.INPUT_ROW = row - ctx.currentVars = null - val columns = exprs.map(_.gen(ctx)) s""" | while ($input.hasNext()) { | InternalRow $row = (InternalRow) $input.next(); - | ${consume(ctx, columns, row).trim} + | ${consume(ctx, null, row).trim} | if (shouldStop()) return; | } """.stripMargin @@ -282,18 +272,15 @@ object WholeStageCodegen { * | * doExecute() ---------> upstreams() -------> upstreams() ------> execute() * | - * -----------------> produce() + * +-----------------> produce() * | * doProduce() -------> produce() * | * doProduce() * | - * consume() - * consumeChild() <-----------| + * doConsume() <--------- consume() * | - * doConsume() - * | - * consumeChild() <----- consume() + * doConsume() <-------- consume() * * SparkPlan A should override doProduce() and doConsume(). 
* @@ -392,44 +379,16 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup throw new UnsupportedOperationException } - override def consumeChild( - ctx: CodegenContext, - child: SparkPlan, - input: Seq[ExprCode], - row: String = null): String = { - + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { val doCopy = if (ctx.copyResult) { ".copy()" } else { "" } - if (row != null) { - // There is an UnsafeRow already - s""" - |append($row$doCopy); - """.stripMargin.trim - } else { - assert(input != null) - if (input.nonEmpty) { - val colExprs = output.zipWithIndex.map { case (attr, i) => - BoundReference(i, attr.dataType, attr.nullable) - } - val evaluateInputs = evaluateVariables(input) - // generate the code to create a UnsafeRow - ctx.currentVars = input - val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false) - s""" - |$evaluateInputs - |${code.code.trim} - |append(${code.value}$doCopy); - """.stripMargin.trim - } else { - // There is no columns - s""" - |append(unsafeRow); - """.stripMargin.trim - } - } + s""" + |${row.code} + |append(${row.value}$doCopy); + """.stripMargin.trim } override def innerChildren: Seq[SparkPlan] = { http://git-wip-us.apache.org/repos/asf/spark/blob/bd94ea4c/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 28945a5..7c215d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -139,7 +139,7 @@ case class TungstenAggregate( } } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): 
String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { if (groupingExpressions.isEmpty) { doConsumeWithoutKeys(ctx, input) } else { http://git-wip-us.apache.org/repos/asf/spark/blob/bd94ea4c/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index ee3f1d7..70e04d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -49,7 +49,7 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) references.filter(a => usedMoreThanOnce.contains(a.exprId)) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { val exprs = projectList.map(x => ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output))) ctx.currentVars = input @@ -107,7 +107,7 @@ case class Filter(condition: Expression, child: SparkPlan) child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { val numOutput = metricTerm(ctx, "numOutputRows") // filter out the nulls http://git-wip-us.apache.org/repos/asf/spark/blob/bd94ea4c/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 
d5ce124..5e573b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -137,7 +137,7 @@ package object debug { child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { consume(ctx, input) } } http://git-wip-us.apache.org/repos/asf/spark/blob/bd94ea4c/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index aa2da28..f5b083c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -110,7 +110,7 @@ case class BroadcastHashJoin( streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { joinType match { case Inner => codegenInner(ctx, input) case LeftOuter | RightOuter => codegenOuter(ctx, input) http://git-wip-us.apache.org/repos/asf/spark/blob/bd94ea4c/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index ca624a5..9643b52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -65,7 +65,7 @@ trait BaseLimit extends UnaryNode with CodegenSupport { child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { val stopEarly = ctx.freshName("stopEarly") ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;") --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
