Repository: spark
Updated Branches:
  refs/heads/master ad0de99f3 -> b373a8886


[SPARK-13415][SQL] Visualize subquery in SQL web UI

## What changes were proposed in this pull request?

This PR adds visualization of subqueries to the SQL web UI. It also improves the explain output for subqueries, especially when they are used together with whole-stage codegen.

For example:
```python
>>> sqlContext.range(100).registerTempTable("range")
>>> sqlContext.sql("select id / (select sum(id) from range) from range where id > (select id from range limit 1)").explain(True)
== Parsed Logical Plan ==
'Project [unresolvedalias(('id / subquery#9), None)]
:  +- 'SubqueryAlias subquery#9
:     +- 'Project [unresolvedalias('sum('id), None)]
:        +- 'UnresolvedRelation `range`, None
+- 'Filter ('id > subquery#8)
   :  +- 'SubqueryAlias subquery#8
   :     +- 'GlobalLimit 1
   :        +- 'LocalLimit 1
   :           +- 'Project [unresolvedalias('id, None)]
   :              +- 'UnresolvedRelation `range`, None
   +- 'UnresolvedRelation `range`, None

== Analyzed Logical Plan ==
(id / scalarsubquery()): double
Project [(cast(id#0L as double) / cast(subquery#9 as double)) AS (id / scalarsubquery())#11]
:  +- SubqueryAlias subquery#9
:     +- Aggregate [(sum(id#0L),mode=Complete,isDistinct=false) AS sum(id)#10L]
:        +- SubqueryAlias range
:           +- Range 0, 100, 1, 4, [id#0L]
+- Filter (id#0L > subquery#8)
   :  +- SubqueryAlias subquery#8
   :     +- GlobalLimit 1
   :        +- LocalLimit 1
   :           +- Project [id#0L]
   :              +- SubqueryAlias range
   :                 +- Range 0, 100, 1, 4, [id#0L]
   +- SubqueryAlias range
      +- Range 0, 100, 1, 4, [id#0L]

== Optimized Logical Plan ==
Project [(cast(id#0L as double) / cast(subquery#9 as double)) AS (id / scalarsubquery())#11]
:  +- SubqueryAlias subquery#9
:     +- Aggregate [(sum(id#0L),mode=Complete,isDistinct=false) AS sum(id)#10L]
:        +- Range 0, 100, 1, 4, [id#0L]
+- Filter (id#0L > subquery#8)
   :  +- SubqueryAlias subquery#8
   :     +- GlobalLimit 1
   :        +- LocalLimit 1
   :           +- Project [id#0L]
   :              +- Range 0, 100, 1, 4, [id#0L]
   +- Range 0, 100, 1, 4, [id#0L]

== Physical Plan ==
WholeStageCodegen
:  +- Project [(cast(id#0L as double) / cast(subquery#9 as double)) AS (id / scalarsubquery())#11]
:     :  +- Subquery subquery#9
:     :     +- WholeStageCodegen
:     :        :  +- TungstenAggregate(key=[], functions=[(sum(id#0L),mode=Final,isDistinct=false)], output=[sum(id)#10L])
:     :        :     +- INPUT
:     :        +- Exchange SinglePartition, None
:     :           +- WholeStageCodegen
:     :              :  +- TungstenAggregate(key=[], functions=[(sum(id#0L),mode=Partial,isDistinct=false)], output=[sum#14L])
:     :              :     +- Range 0, 1, 4, 100, [id#0L]
:     +- Filter (id#0L > subquery#8)
:        :  +- Subquery subquery#8
:        :     +- CollectLimit 1
:        :        +- WholeStageCodegen
:        :           :  +- Project [id#0L]
:        :           :     +- Range 0, 1, 4, 100, [id#0L]
:        +- Range 0, 1, 4, 100, [id#0L]
```

The web UI looks like:

![subquery](https://cloud.githubusercontent.com/assets/40902/13377963/932bcbae-dda7-11e5-82f7-03c9be85d77c.png)

This PR also changes the tree structure of WholeStageCodegen to make it consistent with other plan nodes. Before this change, WholeStageCodegen and InputAdapter both held references to the same plans. Either reference could be updated without the other being notified, which caused problems. This was discovered in #11403.
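Here is a toy Python model of the rewritten `CollapseCodegenStages` rule that makes the new tree shape concrete. It is not Spark code. `insertWholeStageCodegen` opens a stage at the first codegen-capable node, and `insertInputAdapter` wraps the first node that cannot be compiled in an `InputAdapter`, which ends the stage. As a result, `WholeStageCodegen` has a single child and no separate list of inputs. The `Plan` class, its `codegen` flag (which stands in for `supportCodegen`), and the node names are hypothetical.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Plan:
    """Toy stand-in for a SparkPlan node (hypothetical, not Spark's API)."""
    name: str
    children: List["Plan"] = field(default_factory=list)
    codegen: bool = True  # plays the role of supportCodegen(plan)


def insert_input_adapter(plan: Plan) -> Plan:
    # Mirrors insertInputAdapter: wrap nodes that can't be compiled, ending the stage.
    if plan.name == "SortMergeJoin":
        # Each side of a SortMergeJoin is compiled as its own stage.
        sides = [Plan("InputAdapter", [insert_whole_stage_codegen(c)], codegen=False)
                 for c in plan.children]
        return Plan(plan.name, sides, plan.codegen)
    if not plan.codegen:
        return Plan("InputAdapter", [insert_whole_stage_codegen(plan)], codegen=False)
    return Plan(plan.name, [insert_input_adapter(c) for c in plan.children], plan.codegen)


def insert_whole_stage_codegen(plan: Plan) -> Plan:
    # Mirrors insertWholeStageCodegen: open a new stage at the first codegen-capable node.
    if plan.codegen:
        return Plan("WholeStageCodegen", [insert_input_adapter(plan)], codegen=False)
    return Plan(plan.name, [insert_whole_stage_codegen(c) for c in plan.children], plan.codegen)


def shape(plan: Plan) -> str:
    """Compact rendering such as 'A(B,C)', used only to inspect the result."""
    if not plan.children:
        return plan.name
    return plan.name + "(" + ",".join(shape(c) for c in plan.children) + ")"


# Project -> Filter -> Exchange (no codegen) -> Range
plan = Plan("Project", [Plan("Filter", [Plan("Exchange", [Plan("Range")], codegen=False)])])
print(shape(insert_whole_stage_codegen(plan)))
# WholeStageCodegen(Project(Filter(InputAdapter(Exchange(WholeStageCodegen(Range))))))
```

In this model, each `InputAdapter` simply wraps the next stage, so every plan node appears exactly once in the tree. No node is reachable through two different owners.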

## How was this patch tested?

Existing tests, plus manual testing with the example query to check the explain output and the web UI.

Author: Davies Liu <dav...@databricks.com>

Closes #11417 from davies/viz_subquery.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b373a888
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b373a888
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b373a888

Branch: refs/heads/master
Commit: b373a888621ba6f0dd499f47093d4e2e42086dfc
Parents: ad0de99
Author: Davies Liu <dav...@databricks.com>
Authored: Thu Mar 3 17:36:48 2016 -0800
Committer: Yin Huai <yh...@databricks.com>
Committed: Thu Mar 3 17:36:48 2016 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/plans/QueryPlan.scala    |  10 +-
 .../spark/sql/catalyst/trees/TreeNode.scala     |  49 ++++++++
 .../spark/sql/execution/SparkPlanInfo.scala     |   7 +-
 .../spark/sql/execution/WholeStageCodegen.scala | 113 ++++++++-----------
 .../spark/sql/execution/debug/package.scala     |  23 +++-
 .../spark/sql/execution/ui/SparkPlanGraph.scala |  66 ++++++-----
 .../sql/execution/WholeStageCodegenSuite.scala  |  17 +--
 .../org/apache/spark/sql/jdbc/JDBCSuite.scala   |   6 +-
 .../spark/sql/util/DataFrameCallbackSuite.scala |   2 +-
 9 files changed, 166 insertions(+), 127 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b373a888/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 3ff37ff..0e0453b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -229,8 +229,12 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
 
   override def simpleString: String = statePrefix + super.simpleString
 
-  override def treeChildren: Seq[PlanType] = {
-    val subqueries = expressions.flatMap(_.collect {case e: SubqueryExpression => e})
-    children ++ subqueries.map(e => e.plan.asInstanceOf[PlanType])
+  /**
+   * All the subqueries of current plan.
+   */
+  def subqueries: Seq[PlanType] = {
+    expressions.flatMap(_.collect {case e: SubqueryExpression => e.plan.asInstanceOf[PlanType]})
   }
+
+  override def innerChildren: Seq[PlanType] = subqueries
 }
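The new `subqueries` method walks every expression of a plan and keeps the plan carried by each `SubqueryExpression`. `innerChildren` then exposes those plans to the tree-string code. Below is a rough Python analogue of that traversal. It assumes a pre-order walk like `TreeNode.collect`. The `Expr` class and its `plan` field are hypothetical.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Optional


@dataclass
class Expr:
    """Toy expression node (hypothetical); `plan` is set only on subquery expressions."""
    kind: str
    children: List["Expr"] = field(default_factory=list)
    plan: Optional[str] = None


def collect(expr: Expr, pred: Callable[[Expr], bool]) -> List[Expr]:
    # Pre-order walk, like TreeNode.collect: the node itself first, then its children.
    found = [expr] if pred(expr) else []
    for child in expr.children:
        found.extend(collect(child, pred))
    return found


def subqueries(expressions: List[Expr]) -> List[str]:
    # Mirrors QueryPlan.subqueries; `x.plan is not None` stands in for
    # `case e: SubqueryExpression`.
    return [e.plan
            for expr in expressions
            for e in collect(expr, lambda x: x.plan is not None)]


# Filter (id#0L > subquery#8) from the example query in the description
condition = Expr("GreaterThan", [Expr("id#0L"), Expr("ScalarSubquery", plan="subquery#8")])
print(subqueries([condition]))  # ['subquery#8']
```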

http://git-wip-us.apache.org/repos/asf/spark/blob/b373a888/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 2d0bf6b..6b7997e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -447,10 +447,53 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
 
   /**
    * All the nodes that will be used to generate tree string.
+   *
+   * For example:
+   *
+   *   WholeStageCodegen
+   *   +-- SortMergeJoin
+   *       |-- InputAdapter
+   *       |   +-- Sort
+   *       +-- InputAdapter
+   *           +-- Sort
+   *
+   * the treeChildren of WholeStageCodegen will be Seq(Sort, Sort), it will generate a tree string
+   * like this:
+   *
+   *   WholeStageCodegen
+   *   : +- SortMergeJoin
+   *   :    :- INPUT
+   *   :    :- INPUT
+   *   :-  Sort
+   *   :-  Sort
    */
   protected def treeChildren: Seq[BaseType] = children
 
   /**
+   * All the nodes that are parts of this node.
+   *
+   * For example:
+   *
+   *   WholeStageCodegen
+   *   +- SortMergeJoin
+   *      |-- InputAdapter
+   *      |   +-- Sort
+   *      +-- InputAdapter
+   *          +-- Sort
+   *
+   * the innerChildren of WholeStageCodegen will be Seq(SortMergeJoin), it will generate a tree
+   * string like this:
+   *
+   *   WholeStageCodegen
+   *   : +- SortMergeJoin
+   *   :    :- INPUT
+   *   :    :- INPUT
+   *   :-  Sort
+   *   :-  Sort
+   */
+  protected def innerChildren: Seq[BaseType] = Nil
+
+  /**
    * Appends the string represent of this node and its children to the given StringBuilder.
    *
    * The `i`-th element in `lastChildren` indicates whether the ancestor of the current node at
@@ -472,6 +515,12 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     builder.append(simpleString)
     builder.append("\n")
 
+    if (innerChildren.nonEmpty) {
+      innerChildren.init.foreach(_.generateTreeString(
+        depth + 2, lastChildren :+ false :+ false, builder))
+      innerChildren.last.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder)
+    }
+
     if (treeChildren.nonEmpty) {
       treeChildren.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
       treeChildren.last.generateTreeString(depth + 1, lastChildren :+ true, builder)
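Below is a Python sketch of the prefix logic in `generateTreeString` after this change. It assumes that the unchanged prefix code matches the `depth > 0` branch shown in the removed `WholeStageCodegen.generateTreeString` override. Inner children are drawn two levels deeper behind a `:` rail, while tree children hang directly off the node. `Node` and `tree_string` are illustrative names, not Spark APIs.

```python
from typing import List, Optional, Sequence


class Node:
    """Toy tree node (hypothetical) with the two kinds of children TreeNode now renders."""

    def __init__(self, name: str, tree_children=(), inner_children=()):
        self.name = name
        self.tree_children = list(tree_children)
        self.inner_children = list(inner_children)


def tree_string(node: Node, last_children: Sequence[bool] = (),
                out: Optional[List[str]] = None) -> List[str]:
    out = [] if out is None else out
    prefix = ""
    if last_children:  # depth > 0
        prefix = "".join("   " if is_last else ":  " for is_last in last_children[:-1])
        prefix += "+- " if last_children[-1] else ":- "
    out.append(prefix + node.name)
    # innerChildren: depth + 2, with lastChildren :+ false :+ isLast
    for i, child in enumerate(node.inner_children):
        tree_string(child, (*last_children, False, i == len(node.inner_children) - 1), out)
    # treeChildren: depth + 1, with lastChildren :+ isLast
    for i, child in enumerate(node.tree_children):
        tree_string(child, (*last_children, i == len(node.tree_children) - 1), out)
    return out


root = Node("WholeStageCodegen",
            inner_children=[Node("Project", inner_children=[Node("Subquery")],
                                 tree_children=[Node("INPUT")])],
            tree_children=[Node("Exchange", tree_children=[Node("Range")])])
print("\n".join(tree_string(root)))
```

The printed tree has the same shape as the physical plan in the description. The subquery sits under `Project` behind an extra `:` rail, and the stage's input appears after the codegen block.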

http://git-wip-us.apache.org/repos/asf/spark/blob/b373a888/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
index 4dd9928..9019e5d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
@@ -36,11 +36,8 @@ class SparkPlanInfo(
 private[sql] object SparkPlanInfo {
 
   def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = {
-    val children = plan match {
-      case WholeStageCodegen(child, _) => child :: Nil
-      case InputAdapter(child) => child :: Nil
-      case plan => plan.children
-    }
+
+    val children = plan.children ++ plan.subqueries
     val metrics = plan.metrics.toSeq.map { case (key, metric) =>
       new SQLMetricInfo(metric.name.getOrElse(key), metric.id,
         Utils.getFormattedClassName(metric.param))

http://git-wip-us.apache.org/repos/asf/spark/blob/b373a888/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index cb68ca6..6d231bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.execution
 
-import scala.collection.mutable.ArrayBuffer
-
 import org.apache.spark.broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SQLContext
@@ -29,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.toCommentSafeString
 import org.apache.spark.sql.execution.aggregate.TungstenAggregate
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight, SortMergeJoin}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin}
 import org.apache.spark.sql.execution.metric.LongSQLMetricValue
 
 /**
@@ -163,16 +161,12 @@ trait CodegenSupport extends SparkPlan {
  * This is the leaf node of a tree with WholeStageCodegen, is used to generate code that consumes
   * an RDD iterator of InternalRow.
   */
-case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
+case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport {
 
   override def output: Seq[Attribute] = child.output
   override def outputPartitioning: Partitioning = child.outputPartitioning
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 
-  override def doPrepare(): Unit = {
-    child.prepare()
-  }
-
   override def doExecute(): RDD[InternalRow] = {
     child.execute()
   }
@@ -181,8 +175,6 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
     child.doExecuteBroadcast()
   }
 
-  override def supportCodegen: Boolean = false
-
   override def upstreams(): Seq[RDD[InternalRow]] = {
     child.execute() :: Nil
   }
@@ -210,6 +202,8 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
   }
 
   override def simpleString: String = "INPUT"
+
+  override def treeChildren: Seq[SparkPlan] = Nil
 }
 
 /**
@@ -243,22 +237,15 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
  * doCodeGen() will create a CodeGenContext, which will hold a list of variables for input,
   * used to generated code for BoundReference.
   */
-case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
-  extends SparkPlan with CodegenSupport {
-
-  override def supportCodegen: Boolean = false
-
-  override def output: Seq[Attribute] = plan.output
-  override def outputPartitioning: Partitioning = plan.outputPartitioning
-  override def outputOrdering: Seq[SortOrder] = plan.outputOrdering
+case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSupport {
 
-  override def doPrepare(): Unit = {
-    plan.prepare()
-  }
+  override def output: Seq[Attribute] = child.output
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 
   override def doExecute(): RDD[InternalRow] = {
     val ctx = new CodegenContext
-    val code = plan.produce(ctx, this)
+    val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)
     val references = ctx.references.toArray
     val source = s"""
       public Object generate(Object[] references) {
@@ -266,7 +253,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
       }
 
       /** Codegened pipeline for:
-        * ${toCommentSafeString(plan.treeString.trim)}
+        * ${toCommentSafeString(child.treeString.trim)}
         */
       class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
 
@@ -294,7 +281,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
     // println(s"${CodeFormatter.format(cleanedSource)}")
     CodeGenerator.compile(cleanedSource)
 
-    val rdds = plan.upstreams()
+    val rdds = child.asInstanceOf[CodegenSupport].upstreams()
     assert(rdds.size <= 2, "Up to two upstream RDDs can be supported")
     if (rdds.length == 1) {
       rdds.head.mapPartitions { iter =>
@@ -361,34 +348,17 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
     }
   }
 
-  private[sql] override def resetMetrics(): Unit = {
-    plan.foreach(_.resetMetrics())
+  override def innerChildren: Seq[SparkPlan] = {
+    child :: Nil
   }
 
-  override def generateTreeString(
-      depth: Int,
-      lastChildren: Seq[Boolean],
-      builder: StringBuilder): StringBuilder = {
-    if (depth > 0) {
-      lastChildren.init.foreach { isLast =>
-        val prefixFragment = if (isLast) "   " else ":  "
-        builder.append(prefixFragment)
-      }
-
-      val branch = if (lastChildren.last) "+- " else ":- "
-      builder.append(branch)
-    }
-
-    builder.append(simpleString)
-    builder.append("\n")
-
-    plan.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder)
-    if (children.nonEmpty) {
-      children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
-      children.last.generateTreeString(depth + 1, lastChildren :+ true, builder)
-    }
+  private def collectInputs(plan: SparkPlan): Seq[SparkPlan] = plan match {
+    case InputAdapter(c) => c :: Nil
+    case other => other.children.flatMap(collectInputs)
+  }
 
-    builder
+  override def treeChildren: Seq[SparkPlan] = {
+    collectInputs(child)
   }
 
   override def simpleString: String = "WholeStageCodegen"
@@ -416,27 +386,34 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
     case _ => false
   }
 
+  /**
+   * Inserts a InputAdapter on top of those that do not support codegen.
+   */
+  private def insertInputAdapter(plan: SparkPlan): SparkPlan = plan match {
+    case j @ SortMergeJoin(_, _, _, left, right) =>
+      // The children of SortMergeJoin should do codegen separately.
+      j.copy(left = InputAdapter(insertWholeStageCodegen(left)),
+        right = InputAdapter(insertWholeStageCodegen(right)))
+    case p if !supportCodegen(p) =>
+      // collapse them recursively
+      InputAdapter(insertWholeStageCodegen(p))
+    case p =>
+      p.withNewChildren(p.children.map(insertInputAdapter))
+  }
+
+  /**
+   * Inserts a WholeStageCodegen on top of those that support codegen.
+   */
+  private def insertWholeStageCodegen(plan: SparkPlan): SparkPlan = plan match {
+    case plan: CodegenSupport if supportCodegen(plan) =>
+      WholeStageCodegen(insertInputAdapter(plan))
+    case other =>
+      other.withNewChildren(other.children.map(insertWholeStageCodegen))
+  }
+
   def apply(plan: SparkPlan): SparkPlan = {
     if (sqlContext.conf.wholeStageEnabled) {
-      plan.transform {
-        case plan: CodegenSupport if supportCodegen(plan) =>
-          var inputs = ArrayBuffer[SparkPlan]()
-          val combined = plan.transform {
-            // The build side can't be compiled together
-            case b @ BroadcastHashJoin(_, _, _, BuildLeft, _, left, right) =>
-              b.copy(left = apply(left))
-            case b @ BroadcastHashJoin(_, _, _, BuildRight, _, left, right) =>
-              b.copy(right = apply(right))
-            case j @ SortMergeJoin(_, _, _, left, right) =>
-              // The children of SortMergeJoin should do codegen separately.
-              j.copy(left = apply(left), right = apply(right))
-            case p if !supportCodegen(p) =>
-              val input = apply(p)  // collapse them recursively
-              inputs += input
-              InputAdapter(input)
-          }.asInstanceOf[CodegenSupport]
-          WholeStageCodegen(combined, inputs)
-      }
+      insertWholeStageCodegen(plan)
     } else {
       plan
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/b373a888/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index 95d033b..fed88b8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.trees.TreeNodeRef
 import org.apache.spark.sql.internal.SQLConf
 
@@ -68,7 +69,7 @@ package object debug {
     }
   }
 
-  private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode {
+  private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode with CodegenSupport {
     def output: Seq[Attribute] = child.output
 
     implicit object SetAccumulatorParam extends AccumulatorParam[HashSet[String]] {
@@ -86,10 +87,11 @@ package object debug {
     /**
      * A collection of metrics for each column of output.
      * @param elementTypes the actual runtime types for the output.  Useful when there are bugs
-     *        causing the wrong data to be projected.
+     *                     causing the wrong data to be projected.
      */
     case class ColumnMetrics(
-        elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty))
+      elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty))
+
     val tupleCount: Accumulator[Int] = sparkContext.accumulator[Int](0)
 
     val numColumns: Int = child.output.size
@@ -98,7 +100,7 @@ package object debug {
     def dumpStats(): Unit = {
       logDebug(s"== ${child.simpleString} ==")
       logDebug(s"Tuples output: ${tupleCount.value}")
-      child.output.zip(columnStats).foreach { case(attr, metric) =>
+      child.output.zip(columnStats).foreach { case (attr, metric) =>
         val actualDataTypes = metric.elementTypes.value.mkString("{", ",", "}")
         logDebug(s" ${attr.name} ${attr.dataType}: $actualDataTypes")
       }
@@ -108,6 +110,7 @@ package object debug {
       child.execute().mapPartitions { iter =>
         new Iterator[InternalRow] {
           def hasNext: Boolean = iter.hasNext
+
           def next(): InternalRow = {
             val currentRow = iter.next()
             tupleCount += 1
@@ -124,5 +127,17 @@ package object debug {
         }
       }
     }
+
+    override def upstreams(): Seq[RDD[InternalRow]] = {
+      child.asInstanceOf[CodegenSupport].upstreams()
+    }
+
+    override def doProduce(ctx: CodegenContext): String = {
+      child.asInstanceOf[CodegenSupport].produce(ctx, this)
+    }
+
+    override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+      consume(ctx, input)
+    }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/b373a888/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
index 4eb2485..12e586a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
@@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicLong
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.execution.{InputAdapter, SparkPlanInfo, WholeStageCodegen}
+import org.apache.spark.sql.execution.SparkPlanInfo
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
 /**
@@ -73,36 +73,40 @@ private[sql] object SparkPlanGraph {
       edges: mutable.ArrayBuffer[SparkPlanGraphEdge],
       parent: SparkPlanGraphNode,
       subgraph: SparkPlanGraphCluster): Unit = {
-    if (planInfo.nodeName == classOf[WholeStageCodegen].getSimpleName) {
-      val cluster = new SparkPlanGraphCluster(
-        nodeIdGenerator.getAndIncrement(),
-        planInfo.nodeName,
-        planInfo.simpleString,
-        mutable.ArrayBuffer[SparkPlanGraphNode]())
-      nodes += cluster
-      buildSparkPlanGraphNode(
-        planInfo.children.head, nodeIdGenerator, nodes, edges, parent, cluster)
-    } else if (planInfo.nodeName == classOf[InputAdapter].getSimpleName) {
-      buildSparkPlanGraphNode(planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null)
-    } else {
-      val metrics = planInfo.metrics.map { metric =>
-        SQLPlanMetric(metric.name, metric.accumulatorId,
-          SQLMetrics.getMetricParam(metric.metricParam))
-      }
-      val node = new SparkPlanGraphNode(
-        nodeIdGenerator.getAndIncrement(), planInfo.nodeName,
-        planInfo.simpleString, planInfo.metadata, metrics)
-      if (subgraph == null) {
-        nodes += node
-      } else {
-        subgraph.nodes += node
-      }
-
-      if (parent != null) {
-        edges += SparkPlanGraphEdge(node.id, parent.id)
-      }
-      planInfo.children.foreach(
-        buildSparkPlanGraphNode(_, nodeIdGenerator, nodes, edges, node, subgraph))
+    planInfo.nodeName match {
+      case "WholeStageCodegen" =>
+        val cluster = new SparkPlanGraphCluster(
+          nodeIdGenerator.getAndIncrement(),
+          planInfo.nodeName,
+          planInfo.simpleString,
+          mutable.ArrayBuffer[SparkPlanGraphNode]())
+        nodes += cluster
+        buildSparkPlanGraphNode(
+          planInfo.children.head, nodeIdGenerator, nodes, edges, parent, cluster)
+      case "InputAdapter" =>
+        buildSparkPlanGraphNode(planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null)
+      case "Subquery" if subgraph != null =>
+        // Subquery should not be included in WholeStageCodegen
+        buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, parent, null)
+      case _ =>
+        val metrics = planInfo.metrics.map { metric =>
+          SQLPlanMetric(metric.name, metric.accumulatorId,
+            SQLMetrics.getMetricParam(metric.metricParam))
+        }
+        val node = new SparkPlanGraphNode(
+          nodeIdGenerator.getAndIncrement(), planInfo.nodeName,
+          planInfo.simpleString, planInfo.metadata, metrics)
+        if (subgraph == null) {
+          nodes += node
+        } else {
+          subgraph.nodes += node
+        }
+
+        if (parent != null) {
+          edges += SparkPlanGraphEdge(node.id, parent.id)
+        }
+        planInfo.children.foreach(
+          buildSparkPlanGraphNode(_, nodeIdGenerator, nodes, edges, node, subgraph))
     }
   }
 }
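The graph builder can be modeled the same way. This Python sketch is not Spark code. `Info`, `GraphNode`, `GraphCluster` and `build` are hypothetical stand-ins for `SparkPlanInfo`, `SparkPlanGraphNode`, `SparkPlanGraphCluster` and `buildSparkPlanGraphNode`. In the sketch:

- a `WholeStageCodegen` node becomes a cluster;
- an `InputAdapter` leaves the cluster;
- a `Subquery` met inside a cluster is re-dispatched with no subgraph, so it is drawn outside the codegen box.

Children are assumed to already include subqueries, as `SparkPlanInfo.fromSparkPlan` now does with `plan.children ++ plan.subqueries`.

```python
import itertools
from dataclasses import dataclass, field
from typing import Iterator, List, Optional, Tuple


@dataclass
class Info:
    """Hypothetical stand-in for SparkPlanInfo: a name plus children (incl. subqueries)."""
    node_name: str
    children: List["Info"] = field(default_factory=list)


@dataclass
class GraphNode:
    id: int
    name: str


@dataclass
class GraphCluster(GraphNode):
    nodes: List[GraphNode] = field(default_factory=list)


def build(info: Info, ids: Iterator[int], nodes: List[GraphNode],
          edges: List[Tuple[int, int]], parent: Optional[GraphNode],
          subgraph: Optional[GraphCluster]) -> None:
    name = info.node_name
    if name == "WholeStageCodegen":
        cluster = GraphCluster(next(ids), name)
        nodes.append(cluster)
        build(info.children[0], ids, nodes, edges, parent, cluster)
    elif name == "InputAdapter":
        build(info.children[0], ids, nodes, edges, parent, None)
    elif name == "Subquery" and subgraph is not None:
        # Subquery should not be included in WholeStageCodegen
        build(info, ids, nodes, edges, parent, None)
    else:
        node = GraphNode(next(ids), name)
        (nodes if subgraph is None else subgraph.nodes).append(node)
        if parent is not None:
            edges.append((node.id, parent.id))  # edge from child to parent
        for child in info.children:
            build(child, ids, nodes, edges, node, subgraph)


plan = Info("WholeStageCodegen", [
    Info("Project", [
        Info("InputAdapter", [Info("Exchange")]),
        Info("Subquery", [Info("WholeStageCodegen", [Info("Range")])]),
    ]),
])
nodes: List[GraphNode] = []
edges: List[Tuple[int, int]] = []
build(plan, itertools.count(), nodes, edges, None, None)
print([n.name for n in nodes])  # ['WholeStageCodegen', 'Exchange', 'Subquery', 'WholeStageCodegen']
print(edges)                    # [(2, 1), (3, 1), (5, 3)]
```

In this model, `Subquery` ends up as a top-level node connected to `Project`, not inside `Project`'s cluster. The subquery's own `WholeStageCodegen` gets a separate cluster.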

http://git-wip-us.apache.org/repos/asf/spark/blob/b373a888/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index de371d8..e00c762 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -31,14 +31,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
     val df = sqlContext.range(10).filter("id = 1").selectExpr("id + 1")
     val plan = df.queryExecution.executedPlan
     assert(plan.find(_.isInstanceOf[WholeStageCodegen]).isDefined)
-
-    checkThatPlansAgree(
-      sqlContext.range(100),
-      (p: SparkPlan) =>
-        WholeStageCodegen(Filter('a == 1, InputAdapter(p)), Seq()),
-      (p: SparkPlan) => Filter('a == 1, p),
-      sortAnswers = false
-    )
+    assert(df.collect() === Array(Row(2)))
   }
 
   test("Aggregate should be included in WholeStageCodegen") {
@@ -46,7 +39,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
     val plan = df.queryExecution.executedPlan
     assert(plan.find(p =>
       p.isInstanceOf[WholeStageCodegen] &&
-        p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined)
+        p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[TungstenAggregate]).isDefined)
     assert(df.collect() === Array(Row(9, 4.5)))
   }
 
@@ -55,7 +48,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
     val plan = df.queryExecution.executedPlan
     assert(plan.find(p =>
       p.isInstanceOf[WholeStageCodegen] &&
-        p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined)
+        p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[TungstenAggregate]).isDefined)
     assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1)))
   }
 
@@ -66,7 +59,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
     val df = sqlContext.range(10).join(broadcast(smallDF), col("k") === col("id"))
     assert(df.queryExecution.executedPlan.find(p =>
       p.isInstanceOf[WholeStageCodegen] &&
-        p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[BroadcastHashJoin]).isDefined)
+        p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[BroadcastHashJoin]).isDefined)
     assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
   }
 
@@ -75,7 +68,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
     val plan = df.queryExecution.executedPlan
     assert(plan.find(p =>
       p.isInstanceOf[WholeStageCodegen] &&
-        p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[Sort]).isDefined)
+        p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[Sort]).isDefined)
     assert(df.collect() === Array(Row(1), Row(2), Row(3)))
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/b373a888/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 4358c7c..b0d64aa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -210,8 +210,8 @@ class JDBCSuite extends SparkFunSuite
       // the plan only has PhysicalRDD to scan JDBCRelation.
       assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen])
       val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen]
-      assert(node.plan.isInstanceOf[org.apache.spark.sql.execution.PhysicalRDD])
-      assert(node.plan.asInstanceOf[PhysicalRDD].nodeName.contains("JDBCRelation"))
+      assert(node.child.isInstanceOf[org.apache.spark.sql.execution.PhysicalRDD])
+      assert(node.child.asInstanceOf[PhysicalRDD].nodeName.contains("JDBCRelation"))
       df
     }
     assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID < 1")).collect().size == 0)
@@ -248,7 +248,7 @@ class JDBCSuite extends SparkFunSuite
       // cannot compile given predicates.
       assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen])
       val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen]
-      assert(node.plan.isInstanceOf[org.apache.spark.sql.execution.Filter])
+      assert(node.child.isInstanceOf[org.apache.spark.sql.execution.Filter])
       df
     }
     assert(checkNotPushdown(sql("SELECT * FROM foobar WHERE (THEID + 1) < 2")).collect().size == 0)

http://git-wip-us.apache.org/repos/asf/spark/blob/b373a888/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index 15a9562..e7d2b5a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -93,7 +93,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
 
       override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
         val metric = qe.executedPlan match {
-          case w: WholeStageCodegen => w.plan.longMetric("numOutputRows")
+          case w: WholeStageCodegen => w.child.longMetric("numOutputRows")
           case other => other.longMetric("numOutputRows")
         }
         metrics += metric.value.value

