Github user nongli commented on a diff in the pull request:
https://github.com/apache/spark/pull/10735#discussion_r49790417
--- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala ---
@@ -137,6 +156,188 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
protected def doExecute(): RDD[InternalRow]
/**
+ * Whether this SparkPlan support whole stage codegen or not.
+ */
+ protected def supportCodeGen: Boolean = false
+
+ class CompileFailure(e: Exception) extends Exception {}
+
+ /**
+ * Returns an RDD of InternalRow that using generated code to process them.
+ *
+ * Here is the call graph of three SparkPlan, A and B support codegen, but C does not.
+ *
+ * SparkPlan A SparkPlan B SparkPlan C
+ * ===============================================================
+ *
+ * -> doExecute()
+ * |
+ * doCodeGen()
+ * |
+ * produce()
+ * |
+ * doProduce() --------> produce()
+ * |
+ * doProduce() -------> produce()
+ * |
+ * doProduce() (fetch row from upstream)
+ * |
+ * consume()
+ * doConsume() ------------|
+ * |
+ * doConsume() <----- consume()
+ * |
+ * consume() (omit the rows)
+ *
+ * SparkPlan A and B should override doProduce() and doConsume().
+ *
+ * doCodeGen() will create a CodeGenContext, which will hold a list of variables for input,
+ * used to generated code for BoundReference.
+ */
+ protected def doCodeGen(): RDD[InternalRow] = {
+ val ctx = new CodeGenContext
+ val (rdd, code) = produce(ctx, this)
+ val exprType: String = classOf[Expression].getName
+ val references = ctx.references.toArray
+ val source = s"""
+ public Object generate($exprType[] exprs) {
+ return new GeneratedIterator(exprs);
+ }
+
+ class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
+
+ private $exprType[] expressions;
+ ${ctx.declareMutableStates()}
+ private UnsafeRow unsafeRow = new UnsafeRow(${output.length});
+
+ public GeneratedIterator($exprType[] exprs) {
+ expressions = exprs;
+ ${ctx.initMutableStates()}
+ }
+
+ protected void processNext() {
+ $code
+ }
+ }
+ """
+ // try to compile, will fallback if fail
+ // println(s"${CodeFormatter.format(source)}")
+ try {
+ CodeGenerator.compile(source)
+ } catch {
+ case e: Exception =>
+ throw new CompileFailure(e)
+ }
+
+ rdd.mapPartitions { iter =>
+ val clazz = CodeGenerator.compile(source)
+ val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
+ buffer.process(iter)
+ new Iterator[InternalRow] {
+ override def hasNext: Boolean = buffer.hasNext
+ override def next: InternalRow = buffer.next()
+ }
+ }
+ }
+
+ /**
+ * Which SparkPlan is calling produce() of this one. It's itself for the first SparkPlan.
+ */
+ private var parent: SparkPlan = null
+
+ /**
+ * Returns an input RDD of InternalRow and Java source code to process them.
+ */
+ def produce(ctx: CodeGenContext, parent: SparkPlan): (RDD[InternalRow], String) = {
+ this.parent = parent
+ doProduce(ctx)
+ }
+
+ /**
+ * Generate the Java source code to process, should be overrided by subclass to support codegen.
+ *
+ * doProduce() usually generate the framework, for example, aggregation could generate this:
+ *
+ * if (!initialized) {
+ * # create a hash map, then build the aggregation hash map
+ * # call child.produce()
+ * initialized = true;
+ * }
+ * while (hashmap.hasNext()) {
+ * row = hashmap.next();
+ * # build the aggregation results
+ * # create varialbles for results
+ * # call consume(), wich will call parent.doConsume()
+ * }
+ */
+ protected def doProduce(ctx: CodeGenContext): (RDD[InternalRow], String) = {
+ val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
+ val columns = exprs.map(_.gen(ctx))
+ val code = s"""
+ | while (input.hasNext()) {
+ | InternalRow i = (InternalRow) input.next();
+ | ${columns.map(_.code).mkString("\n")}
+ | ${consume(ctx, columns)}
+ | }
+ """.stripMargin
+ (doExecute(), code)
+ }
+
+ /**
+ * Consume the columns generated from current SparkPlan, call it's parent or create an iterator.
+ */
+ protected def consume(
+ ctx: CodeGenContext,
+ columns: Seq[GeneratedExpressionCode]): String = {
+
+ assert(columns.length == output.length)
+ // Save the generated columns, will be used to generate BoundReference by parent SparkPlan.
+ ctx.currentVars = columns.toArray
+
+ if (parent eq this) {
+ // This is the first SparkPlan, omit the rows.
--- End diff --
omit? do you mean emit?
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]