Repository: spark
Updated Branches:
  refs/heads/master cc30ef800 -> fcf66a327


[SPARK-21657][SQL] optimize explode quadratic memory consumption

## What changes were proposed in this pull request?

The issue has been raised in two Jira tickets: 
[SPARK-21657](https://issues.apache.org/jira/browse/SPARK-21657) and 
[SPARK-16998](https://issues.apache.org/jira/browse/SPARK-16998). In collection 
generators such as explode/inline we create many rows from each input row, and 
currently each generated row also carries the column it was generated from. 
For example, if one row contains a 10k-element array, that array gets copied 
into each of the 10k generated rows, resulting in quadratic memory 
consumption. However, it is a common case that the original column gets 
projected out right after the explode, so in those cases the duplication can 
be avoided entirely.
In this solution we propose to identify this situation in the optimizer and 
turn on a flag for omitting the original column in the generation process.
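
To make the pattern concrete, here is a minimal sketch (table and column names 
are invented for illustration) of a query that hits this case: the array 
column feeds the generator and is then dropped by the projection, so copying 
it into every generated row is pure overhead.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.explode

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// A single row holding a large array: previously `arr` was copied into every
// one of the 10k rows produced by explode, even though the final projection
// keeps only the exploded element.
val df = Seq(("key1", (1 to 10000).toArray)).toDF("key", "arr")

df.select($"key", explode($"arr").as("elem"))  // `arr` itself is not selected,
  .count()                                     // so it can now be omitted
```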

## How was this patch tested?

1. We added a benchmark to MiscBenchmark that shows a 16x improvement in 
runtime.
2. We ran some of the other MiscBenchmark tests and they show roughly 15% 
improvements.
3. We ran this code on a case from our production data, with rows containing 
arrays of ~200k elements, and it reduced the runtime from 6 hours to 3 
minutes.

Author: oraviv <ora...@paypal.com>
Author: uzadude <ohad.ra...@gmail.com>
Author: uzadude <15645757+uzad...@users.noreply.github.com>

Closes #19683 from uzadude/optimize_explode.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fcf66a32
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fcf66a32
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fcf66a32

Branch: refs/heads/master
Commit: fcf66a32760c74e601acb537c51b2311ece6e9d5
Parents: cc30ef8
Author: oraviv <ora...@paypal.com>
Authored: Fri Dec 29 21:08:34 2017 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Fri Dec 29 21:08:34 2017 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  |  6 +--
 .../sql/catalyst/analysis/CheckAnalysis.scala   |  4 +-
 .../apache/spark/sql/catalyst/dsl/package.scala |  6 +--
 .../sql/catalyst/optimizer/Optimizer.scala      | 13 +++---
 .../spark/sql/catalyst/parser/AstBuilder.scala  |  2 +-
 .../plans/logical/basicLogicalOperators.scala   | 21 ++++++----
 .../catalyst/optimizer/ColumnPruningSuite.scala | 44 ++++++++++++--------
 .../optimizer/FilterPushdownSuite.scala         | 14 +++----
 .../sql/catalyst/parser/PlanParserSuite.scala   | 16 ++++---
 .../scala/org/apache/spark/sql/Dataset.scala    |  4 +-
 .../spark/sql/execution/GenerateExec.scala      | 29 ++++++-------
 .../spark/sql/execution/SparkStrategies.scala   |  6 +--
 .../sql/execution/benchmark/MiscBenchmark.scala | 37 ++++++++++++++++
 13 files changed, 128 insertions(+), 74 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 7f2128e..1f7191c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -696,7 +696,7 @@ class Analyzer(
          (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))
 
         case oldVersion: Generate
-            if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty =>
+            if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
           val newOutput = oldVersion.generatorOutput.map(_.newInstance())
           (oldVersion, oldVersion.copy(generatorOutput = newOutput))
 
@@ -1138,7 +1138,7 @@ class Analyzer(
           case g: Generate =>
             val maybeResolvedExprs = exprs.map(resolveExpression(_, g))
            val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, g.child)
-            (newExprs, g.copy(join = true, child = newChild))
+            (newExprs, g.copy(unrequiredChildIndex = Nil, child = newChild))

          // For `Distinct` and `SubqueryAlias`, we can't recursively resolve and add attributes
           // via its children.
@@ -1578,7 +1578,7 @@ class Analyzer(
             resolvedGenerator =
               Generate(
                 generator,
-                join = projectList.size > 1, // Only join if there are other expressions in SELECT.
+                unrequiredChildIndex = Nil,
                outer = outer,
                qualifier = None,
                generatorOutput = ResolveGenerate.makeGeneratorOutput(generator, names),

http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 6894aed..bbcec56 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -608,8 +608,8 @@ trait CheckAnalysis extends PredicateHelper {
       // allows to have correlation under it
       // but must not host any outer references.
       // Note:
-      // Generator with join=false is treated as Category 4.
-      case g: Generate if g.join =>
+      // Generator with requiredChildOutput.isEmpty is treated as Category 4.
+      case g: Generate if g.requiredChildOutput.nonEmpty =>
         failOnInvalidOuterReference(g)
 
       // Category 4: Any other operators not in the above 3 categories

http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 7c100af..59cb26d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -359,12 +359,12 @@ package object dsl {
 
       def generate(
         generator: Generator,
-        join: Boolean = false,
+        unrequiredChildIndex: Seq[Int] = Nil,
         outer: Boolean = false,
         alias: Option[String] = None,
         outputNames: Seq[String] = Nil): LogicalPlan =
-        Generate(generator, join = join, outer = outer, alias,
-          outputNames.map(UnresolvedAttribute(_)), logicalPlan)
+        Generate(generator, unrequiredChildIndex, outer,
+          alias, outputNames.map(UnresolvedAttribute(_)), logicalPlan)
 
      def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan =
         InsertIntoTable(

http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 6a4d1e9..eeb1b13 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -456,12 +456,15 @@ object ColumnPruning extends Rule[LogicalPlan] {
       f.copy(child = prunedChild(child, f.references))
    case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty =>
      e.copy(child = prunedChild(child, e.references))
-    case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty =>
-      g.copy(child = prunedChild(g.child, g.references))

-    // Turn off `join` for Generate if no column from it's child is used
-    case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) =>
-      p.copy(child = g.copy(join = false))
+    // prune unrequired references
+    case p @ Project(_, g: Generate) if p.references != g.outputSet =>
+      val requiredAttrs = p.references -- g.producedAttributes ++ g.generator.references
+      val newChild = prunedChild(g.child, requiredAttrs)
+      val unrequired = g.generator.references -- p.references
+      val unrequiredIndices = newChild.output.zipWithIndex.filter(t => unrequired.contains(t._1))
+        .map(_._2)
+      p.copy(child = g.copy(child = newChild, unrequiredChildIndex = unrequiredIndices))
 
     // Eliminate unneeded attributes from right side of a Left Existence Join.
     case j @ Join(_, right, LeftExistence(_), _) =>

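The heart of the new ColumnPruning case above is turning the set of
generator-only references into positions in the child's output. A standalone
sketch of that index computation, with plain strings standing in for Catalyst
attributes (all names here are illustrative, not Spark API):

```scala
// `arr` is referenced by the generator but not by the projection above the
// Generate, so its position in the child's output becomes "unrequired" and
// GenerateExec will not copy it into the output rows.
val childOutput = Seq("key", "arr")            // stand-in for newChild.output
val generatorRefs = Set("arr")                 // stand-in for g.generator.references
val projectRefs = Set("key", "elem")           // stand-in for p.references

val unrequired = generatorRefs -- projectRefs  // needed only to generate, not to output
val unrequiredIndices = childOutput.zipWithIndex
  .filter { case (attr, _) => unrequired.contains(attr) }
  .map(_._2)
// unrequiredIndices == Seq(1), i.e. the index of `arr`
```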
http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 7651d11..bdc357d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -623,7 +623,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
     val expressions = expressionList(ctx.expression)
     Generate(
       UnresolvedGenerator(visitFunctionName(ctx.qualifiedName), expressions),
-      join = true,
+      unrequiredChildIndex = Nil,
       outer = ctx.OUTER != null,
       Some(ctx.tblName.getText.toLowerCase),
       ctx.colName.asScala.map(_.getText).map(UnresolvedAttribute.apply),

http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index cd47455..95e099c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -73,8 +73,13 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
  * their output.
  *
  * @param generator the generator expression
- * @param join  when true, each output row is implicitly joined with the input tuple that produced
- *              it.
+ * @param unrequiredChildIndex this parameter starts as Nil and gets filled by the Optimizer.
+ *                             It's used as an optimization for omitting data generation that will
+ *                             be discarded next by a projection.
+ *                             A common use case is when we explode(array(..)) and are interested
+ *                             only in the exploded data and not in the original array. Before this
+ *                             optimization the array got duplicated for each of its elements,
+ *                             causing O(n^2) memory consumption (see [SPARK-21657]).
 * @param outer when true, each input row will be output at least once, even if the output of the
 *              given `generator` is empty.
  * @param qualifier Qualifier for the attributes of generator(UDTF)
@@ -83,15 +88,17 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
  */
 case class Generate(
     generator: Generator,
-    join: Boolean,
+    unrequiredChildIndex: Seq[Int],
     outer: Boolean,
     qualifier: Option[String],
     generatorOutput: Seq[Attribute],
     child: LogicalPlan)
   extends UnaryNode {
 
-  /** The set of all attributes produced by this node. */
-  def generatedSet: AttributeSet = AttributeSet(generatorOutput)
+  lazy val requiredChildOutput: Seq[Attribute] = {
+    val unrequiredSet = unrequiredChildIndex.toSet
+    child.output.zipWithIndex.filterNot(t => unrequiredSet.contains(t._2)).map(_._1)
+  }
 
   override lazy val resolved: Boolean = {
     generator.resolved &&
@@ -114,9 +121,7 @@ case class Generate(
     nullableOutput
   }
 
-  def output: Seq[Attribute] = {
-    if (join) child.output ++ qualifiedGeneratorOutput else qualifiedGeneratorOutput
-  }
+  def output: Seq[Attribute] = requiredChildOutput ++ qualifiedGeneratorOutput
 }
 
 case class Filter(condition: Expression, child: LogicalPlan)

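Given an unrequiredChildIndex, requiredChildOutput is simply the child's
output with those positions dropped, and the node's output is that list
followed by the generator's columns. A toy illustration of the same
computation with strings in place of attributes (values invented):

```scala
// Mirrors Generate.requiredChildOutput and Generate.output with plain values.
val childOutput = Seq("key", "arr")         // child.output
val unrequiredChildIndex = Seq(1)           // filled in by the Optimizer
val qualifiedGeneratorOutput = Seq("elem")  // the generator's columns

val unrequiredSet = unrequiredChildIndex.toSet
val requiredChildOutput = childOutput.zipWithIndex
  .filterNot { case (_, i) => unrequiredSet.contains(i) }
  .map(_._1)

val output = requiredChildOutput ++ qualifiedGeneratorOutput
// output == Seq("key", "elem"); the big `arr` column never flows downstream
```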
http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index 77e4eff..9f0f7e1 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -38,54 +38,64 @@ class ColumnPruningSuite extends PlanTest {
       CollapseProject) :: Nil
   }
 
-  test("Column pruning for Generate when Generate.join = false") {
-    val input = LocalRelation('a.int, 'b.array(StringType))
+  test("Column pruning for Generate when Generate.unrequiredChildIndex = 
child.output") {
+    val input = LocalRelation('a.int, 'b.int, 'c.array(StringType))
 
-    val query = input.generate(Explode('b), join = false).analyze
+    val query =
+      input
+        .generate(Explode('c), outputNames = "explode" :: Nil)
+        .select('c, 'explode)
+        .analyze
 
     val optimized = Optimize.execute(query)
 
-    val correctAnswer = input.select('b).generate(Explode('b), join = false).analyze
+    val correctAnswer =
+      input
+        .select('c)
+        .generate(Explode('c), outputNames = "explode" :: Nil)
+        .analyze
 
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Column pruning for Generate when Generate.join = true") {
-    val input = LocalRelation('a.int, 'b.int, 'c.array(StringType))
+  test("Fill Generate.unrequiredChildIndex if possible") {
+    val input = LocalRelation('b.array(StringType))
 
     val query =
       input
-        .generate(Explode('c), join = true, outputNames = "explode" :: Nil)
-        .select('a, 'explode)
+        .generate(Explode('b), outputNames = "explode" :: Nil)
+        .select(('explode + 1).as("result"))
         .analyze
 
     val optimized = Optimize.execute(query)
 
     val correctAnswer =
       input
-        .select('a, 'c)
-        .generate(Explode('c), join = true, outputNames = "explode" :: Nil)
-        .select('a, 'explode)
+        .generate(Explode('b), unrequiredChildIndex = input.output.zipWithIndex.map(_._2),
+          outputNames = "explode" :: Nil)
+         .select(('explode + 1).as("result"))
         .analyze
 
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Turn Generate.join to false if possible") {
-    val input = LocalRelation('b.array(StringType))
+  test("Another fill Generate.unrequiredChildIndex if possible") {
+    val input = LocalRelation('a.int, 'b.int, 'c1.string, 'c2.string)
 
     val query =
       input
-        .generate(Explode('b), join = true, outputNames = "explode" :: Nil)
-        .select(('explode + 1).as("result"))
+        .generate(Explode(CreateArray(Seq('c1, 'c2))), outputNames = "explode" :: Nil)
+        .select('a, 'c1, 'explode)
         .analyze
 
     val optimized = Optimize.execute(query)
 
     val correctAnswer =
       input
-        .generate(Explode('b), join = false, outputNames = "explode" :: Nil)
-        .select(('explode + 1).as("result"))
+        .select('a, 'c1, 'c2)
+        .generate(Explode(CreateArray(Seq('c1, 'c2))),
+          unrequiredChildIndex = Seq(2),
+          outputNames = "explode" :: Nil)
         .analyze
 
     comparePlans(optimized, correctAnswer)

http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 641824e..4a23179 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -624,14 +624,14 @@ class FilterPushdownSuite extends PlanTest {
   test("generate: predicate referenced no generated column") {
     val originalQuery = {
       testRelationWithArrayType
-        .generate(Explode('c_arr), true, false, Some("arr"))
+        .generate(Explode('c_arr), alias = Some("arr"))
         .where(('b >= 5) && ('a > 6))
     }
     val optimized = Optimize.execute(originalQuery.analyze)
     val correctAnswer = {
       testRelationWithArrayType
         .where(('b >= 5) && ('a > 6))
-        .generate(Explode('c_arr), true, false, Some("arr")).analyze
+        .generate(Explode('c_arr), alias = Some("arr")).analyze
     }
 
     comparePlans(optimized, correctAnswer)
@@ -640,14 +640,14 @@ class FilterPushdownSuite extends PlanTest {
   test("generate: non-deterministic predicate referenced no generated column") 
{
     val originalQuery = {
       testRelationWithArrayType
-        .generate(Explode('c_arr), true, false, Some("arr"))
+        .generate(Explode('c_arr), alias = Some("arr"))
         .where(('b >= 5) && ('a + Rand(10).as("rnd") > 6) && ('col > 6))
     }
     val optimized = Optimize.execute(originalQuery.analyze)
     val correctAnswer = {
       testRelationWithArrayType
         .where('b >= 5)
-        .generate(Explode('c_arr), true, false, Some("arr"))
+        .generate(Explode('c_arr), alias = Some("arr"))
         .where('a + Rand(10).as("rnd") > 6 && 'col > 6)
         .analyze
     }
@@ -659,14 +659,14 @@ class FilterPushdownSuite extends PlanTest {
     val generator = Explode('c_arr)
     val originalQuery = {
       testRelationWithArrayType
-        .generate(generator, true, false, Some("arr"))
+        .generate(generator, alias = Some("arr"))
         .where(('b >= 5) && ('c > 6))
     }
     val optimized = Optimize.execute(originalQuery.analyze)
     val referenceResult = {
       testRelationWithArrayType
         .where('b >= 5)
-        .generate(generator, true, false, Some("arr"))
+        .generate(generator, alias = Some("arr"))
         .where('c > 6).analyze
     }
 
@@ -687,7 +687,7 @@ class FilterPushdownSuite extends PlanTest {
   test("generate: all conjuncts referenced generated column") {
     val originalQuery = {
       testRelationWithArrayType
-        .generate(Explode('c_arr), true, false, Some("arr"))
+        .generate(Explode('c_arr), alias = Some("arr"))
         .where(('col > 6) || ('b > 5)).analyze
     }
     val optimized = Optimize.execute(originalQuery)

http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index d34a83c..812bfdd 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -276,7 +276,7 @@ class PlanParserSuite extends AnalysisTest {
     assertEqual(
       "select * from t lateral view explode(x) expl as x",
       table("t")
-        .generate(explode, join = true, outer = false, Some("expl"), Seq("x"))
+        .generate(explode, alias = Some("expl"), outputNames = Seq("x"))
         .select(star()))
 
     // Multiple lateral views
@@ -286,12 +286,12 @@ class PlanParserSuite extends AnalysisTest {
         |lateral view explode(x) expl
         |lateral view outer json_tuple(x, y) jtup q, z""".stripMargin,
       table("t")
-        .generate(explode, join = true, outer = false, Some("expl"), Seq.empty)
-        .generate(jsonTuple, join = true, outer = true, Some("jtup"), Seq("q", "z"))
+        .generate(explode, alias = Some("expl"))
+        .generate(jsonTuple, outer = true, alias = Some("jtup"), outputNames = Seq("q", "z"))
         .select(star()))
 
     // Multi-Insert lateral views.
-    val from = table("t1").generate(explode, join = true, outer = false, 
Some("expl"), Seq("x"))
+    val from = table("t1").generate(explode, alias = Some("expl"), outputNames 
= Seq("x"))
     assertEqual(
       """from t1
         |lateral view explode(x) expl as x
@@ -303,7 +303,7 @@ class PlanParserSuite extends AnalysisTest {
         |where s < 10
       """.stripMargin,
       Union(from
-        .generate(jsonTuple, join = true, outer = false, Some("jtup"), Seq("q", "z"))
+        .generate(jsonTuple, alias = Some("jtup"), outputNames = Seq("q", "z"))
         .select(star())
         .insertInto("t2"),
         from.where('s < 10).select(star()).insertInto("t3")))
@@ -312,10 +312,8 @@ class PlanParserSuite extends AnalysisTest {
     val expected = table("t")
       .generate(
         UnresolvedGenerator(FunctionIdentifier("posexplode"), Seq('x)),
-        join = true,
-        outer = false,
-        Some("posexpl"),
-        Seq("x", "y"))
+        alias = Some("posexpl"),
+        outputNames = Seq("x", "y"))
       .select(star())
     assertEqual(
       "select * from t lateral view posexplode(x) posexpl as x, y",

http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 209b800..77e5712 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2095,7 +2095,7 @@ class Dataset[T] private[sql](
    val generator = UserDefinedGenerator(elementSchema, rowFunction, input.map(_.expr))
 
     withPlan {
-      Generate(generator, join = true, outer = false,
+      Generate(generator, unrequiredChildIndex = Nil, outer = false,
         qualifier = None, generatorOutput = Nil, planWithBarrier)
     }
   }
@@ -2136,7 +2136,7 @@ class Dataset[T] private[sql](
    val generator = UserDefinedGenerator(elementSchema, rowFunction, apply(inputColumn).expr :: Nil)
 
     withPlan {
-      Generate(generator, join = true, outer = false,
+      Generate(generator, unrequiredChildIndex = Nil, outer = false,
         qualifier = None, generatorOutput = Nil, planWithBarrier)
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index e1562be..0c2c4a1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -47,8 +47,7 @@ private[execution] sealed case class LazyIterator(func: () => TraversableOnce[In
  * terminate().
  *
  * @param generator the generator expression
- * @param join  when true, each output row is implicitly joined with the input tuple that produced
- *              it.
+ * @param requiredChildOutput required attributes from child's output
 * @param outer when true, each input row will be output at least once, even if the output of the
 *              given `generator` is empty.
 * @param generatorOutput the qualified output attributes of the generator of this node, which
@@ -57,19 +56,13 @@ private[execution] sealed case class LazyIterator(func: () => TraversableOnce[In
  */
 case class GenerateExec(
     generator: Generator,
-    join: Boolean,
+    requiredChildOutput: Seq[Attribute],
     outer: Boolean,
     generatorOutput: Seq[Attribute],
     child: SparkPlan)
   extends UnaryExecNode with CodegenSupport {
 
-  override def output: Seq[Attribute] = {
-    if (join) {
-      child.output ++ generatorOutput
-    } else {
-      generatorOutput
-    }
-  }
+  override def output: Seq[Attribute] = requiredChildOutput ++ generatorOutput
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output 
rows"))
@@ -85,11 +78,19 @@ case class GenerateExec(
     val numOutputRows = longMetric("numOutputRows")
     child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
      val generatorNullRow = new GenericInternalRow(generator.elementSchema.length)
-      val rows = if (join) {
+      val rows = if (requiredChildOutput.nonEmpty) {
+
+        val pruneChildForResult: InternalRow => InternalRow =
+          if (child.outputSet == AttributeSet(requiredChildOutput)) {
+            identity
+          } else {
+            UnsafeProjection.create(requiredChildOutput, child.output)
+          }
+
         val joinedRow = new JoinedRow
         iter.flatMap { row =>
-          // we should always set the left (child output)
-          joinedRow.withLeft(row)
+          // we should always set the left (required child output)
+          joinedRow.withLeft(pruneChildForResult(row))
           val outputRows = boundGenerator.eval(row)
           if (outer && outputRows.isEmpty) {
             joinedRow.withRight(generatorNullRow) :: Nil
@@ -136,7 +137,7 @@ case class GenerateExec(
 
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
     // Add input rows to the values when we are joining
-    val values = if (join) {
+    val values = if (requiredChildOutput.nonEmpty) {
       input
     } else {
       Seq.empty

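On the execution side the change amounts to projecting each input row down to
the required attributes once, before it is joined with every generated row. A
conceptual analogue using plain Scala collections (Row and generate here are
invented stand-ins for illustration, not Spark API):

```scala
// Each output row carries only the pruned left side, so an array referenced
// solely by the generator is read once per input row, not copied per output row.
case class Row(fields: Vector[Any])

def generate(
    input: Iterator[Row],
    requiredIndices: Seq[Int],     // plays the role of pruneChildForResult
    gen: Row => Seq[Row]): Iterator[Row] =
  input.flatMap { row =>
    val pruned = Row(requiredIndices.map(row.fields).toVector)  // project once
    gen(row).map(out => Row(pruned.fields ++ out.fields))       // JoinedRow(left, right)
  }
```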
http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 0ed7c2f..9102948 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -499,10 +499,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.GlobalLimitExec(limit, planLater(child)) :: Nil
       case logical.Union(unionChildren) =>
         execution.UnionExec(unionChildren.map(planLater)) :: Nil
-      case g @ logical.Generate(generator, join, outer, _, _, child) =>
+      case g @ logical.Generate(generator, _, outer, _, _, child) =>
         execution.GenerateExec(
-          generator, join = join, outer = outer, g.qualifiedGeneratorOutput,
-          planLater(child)) :: Nil
+          generator, g.requiredChildOutput, outer,
+          g.qualifiedGeneratorOutput, planLater(child)) :: Nil
       case _: logical.OneRowRelation =>
         execution.RDDScanExec(Nil, singleRowRdd, "OneRowRelation") :: Nil
       case r: logical.Range =>

http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala
index 01773c2..f039aea 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala
@@ -202,6 +202,42 @@ class MiscBenchmark extends BenchmarkBase {
    generate inline array wholestage off          6901 / 6928          2.4        411.3       1.0X
    generate inline array wholestage on           1001 / 1010         16.8         59.7       6.9X
      */
+
+    val M = 60000
+    runBenchmark("generate big struct array", M) {
+      import sparkSession.implicits._
+      val df = sparkSession.sparkContext.parallelize(Seq(("1",
+        Array.fill(M)({
+          val i = math.random
+          (i.toString, (i + 1).toString, (i + 2).toString, (i + 3).toString)
+        })))).toDF("col", "arr")
+
+      df.selectExpr("*", "explode(arr) as arr_col")
+        .select("col", "arr_col.*").count
+    }
+
+    /*
+    Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+    Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz
+
+    This tests the impact of adding the Generate.unrequiredChildIndex optimization;
+    we see an enormous 250x improvement in this case, and the gap grows as O(n^2).
+
+    with Optimization ON:
+
+    generate big struct array:               Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    ------------------------------------------------------------------------------------------------
+    generate big struct array wholestage off       331 /  378          0.2       5524.9       1.0X
+    generate big struct array wholestage on        205 /  232          0.3       3413.1       1.6X
+
+    with Optimization OFF:
+
+    generate big struct array:               Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    ------------------------------------------------------------------------------------------------
+    generate big struct array wholestage off    49697 / 51496          0.0     828277.7       1.0X
+    generate big struct array wholestage on     50558 / 51434          0.0     842641.6       1.0X
+     */
+
   }
 
   ignore("generate regular generator") {
@@ -227,4 +263,5 @@ class MiscBenchmark extends BenchmarkBase {
    generate stack wholestage on                   836 /  847         20.1          49.8      15.5X
      */
   }
+
 }

