Repository: spark Updated Branches: refs/heads/master cc30ef800 -> fcf66a327
[SPARK-21657][SQL] optimize explode quadratic memory consumption ## What changes were proposed in this pull request? The issue has been raised in two Jira tickets: [SPARK-21657](https://issues.apache.org/jira/browse/SPARK-21657), [SPARK-16998](https://issues.apache.org/jira/browse/SPARK-16998). Basically, collection generators like explode/inline create many output rows from each input row. Currently each exploded row also contains the column from which it was generated. As a result, if one row holds a 10k-element array, that array gets copied 10k times, once into each output row, which leads to quadratic memory consumption. However, it is a common case that the original column gets projected out after the explode, so we can avoid duplicating it. In this solution we propose to identify this situation in the optimizer and turn on a flag for omitting the original column in the generation process. ## How was this patch tested? 1. We added a benchmark test to MiscBenchmark that shows a 16x improvement in runtime. 2. We ran some of the other tests in MiscBenchmark and they show 15% improvements. 3. We ran this code on a specific case from our production data with rows containing arrays of size ~200k and it reduced the runtime from 6 hours to 3 minutes. Author: oraviv <ora...@paypal.com> Author: uzadude <ohad.ra...@gmail.com> Author: uzadude <15645757+uzad...@users.noreply.github.com> Closes #19683 from uzadude/optimize_explode. 
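To illustrate the problem and the fix outside of Spark, here is a minimal, hypothetical plain-Scala sketch (no Spark dependency; `explodeJoin`, `explodePruned` and the positional-row representation are invented for illustration). The old behavior joins every generated element back to the full child row, so the array it came from is duplicated per element; the new `unrequiredChildIndex` mechanism drops child columns that a downstream projection discards before the join happens:

```scala
object ExplodePruningSketch {
  type Row = Seq[Any]

  // Old behavior (join = true): every output row carries the full child row,
  // including the large array that explode consumed -> O(n^2) cells.
  def explodeJoin(rows: Seq[Row], arrIndex: Int): Seq[Row] =
    rows.flatMap { row =>
      row(arrIndex).asInstanceOf[Seq[Any]].map(elem => row :+ elem)
    }

  // New behavior: child columns whose positions are "unrequired" are dropped
  // before joining, so the exploded array is not duplicated per output row.
  def explodePruned(rows: Seq[Row], arrIndex: Int, unrequired: Set[Int]): Seq[Row] =
    rows.flatMap { row =>
      val kept = row.zipWithIndex.collect { case (v, i) if !unrequired(i) => v }
      row(arrIndex).asInstanceOf[Seq[Any]].map(elem => kept :+ elem)
    }

  def main(args: Array[String]): Unit = {
    // One child row with columns (id, arr); explode arr (column index 1).
    val rows: Seq[Row] = Seq(Seq("r1", Seq(10, 20, 30)))

    val joined = explodeJoin(rows, arrIndex = 1)
    // Each of the 3 output rows still holds the whole 3-element array.
    assert(joined == Seq(
      Seq("r1", Seq(10, 20, 30), 10),
      Seq("r1", Seq(10, 20, 30), 20),
      Seq("r1", Seq(10, 20, 30), 30)))

    // A projection that only needs (id, exploded element) lets the optimizer
    // mark the array column (index 1) as unrequired.
    val pruned = explodePruned(rows, arrIndex = 1, unrequired = Set(1))
    assert(pruned == Seq(Seq("r1", 10), Seq("r1", 20), Seq("r1", 30)))
  }
}
```

This mirrors the intent of the patch, not its implementation: in Spark the pruning decision is made by the `ColumnPruning` rule on the logical plan, and the per-row projection happens in `GenerateExec`.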
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fcf66a32 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fcf66a32 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fcf66a32 Branch: refs/heads/master Commit: fcf66a32760c74e601acb537c51b2311ece6e9d5 Parents: cc30ef8 Author: oraviv <ora...@paypal.com> Authored: Fri Dec 29 21:08:34 2017 +0800 Committer: Wenchen Fan <wenc...@databricks.com> Committed: Fri Dec 29 21:08:34 2017 +0800 ---------------------------------------------------------------------- .../spark/sql/catalyst/analysis/Analyzer.scala | 6 +-- .../sql/catalyst/analysis/CheckAnalysis.scala | 4 +- .../apache/spark/sql/catalyst/dsl/package.scala | 6 +-- .../sql/catalyst/optimizer/Optimizer.scala | 13 +++--- .../spark/sql/catalyst/parser/AstBuilder.scala | 2 +- .../plans/logical/basicLogicalOperators.scala | 21 ++++++---- .../catalyst/optimizer/ColumnPruningSuite.scala | 44 ++++++++++++-------- .../optimizer/FilterPushdownSuite.scala | 14 +++---- .../sql/catalyst/parser/PlanParserSuite.scala | 16 ++++--- .../scala/org/apache/spark/sql/Dataset.scala | 4 +- .../spark/sql/execution/GenerateExec.scala | 29 ++++++------- .../spark/sql/execution/SparkStrategies.scala | 6 +-- .../sql/execution/benchmark/MiscBenchmark.scala | 37 ++++++++++++++++ 13 files changed, 128 insertions(+), 74 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7f2128e..1f7191c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -696,7 +696,7 @@ class Analyzer( (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))) case oldVersion: Generate - if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty => + if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty => val newOutput = oldVersion.generatorOutput.map(_.newInstance()) (oldVersion, oldVersion.copy(generatorOutput = newOutput)) @@ -1138,7 +1138,7 @@ class Analyzer( case g: Generate => val maybeResolvedExprs = exprs.map(resolveExpression(_, g)) val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, g.child) - (newExprs, g.copy(join = true, child = newChild)) + (newExprs, g.copy(unrequiredChildIndex = Nil, child = newChild)) // For `Distinct` and `SubqueryAlias`, we can't recursively resolve and add attributes // via its children. @@ -1578,7 +1578,7 @@ class Analyzer( resolvedGenerator = Generate( generator, - join = projectList.size > 1, // Only join if there are other expressions in SELECT. 
+ unrequiredChildIndex = Nil, outer = outer, qualifier = None, generatorOutput = ResolveGenerate.makeGeneratorOutput(generator, names), http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 6894aed..bbcec56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -608,8 +608,8 @@ trait CheckAnalysis extends PredicateHelper { // allows to have correlation under it // but must not host any outer references. // Note: - // Generator with join=false is treated as Category 4. - case g: Generate if g.join => + // Generator with requiredChildOutput.isEmpty is treated as Category 4. 
+ case g: Generate if g.requiredChildOutput.nonEmpty => failOnInvalidOuterReference(g) // Category 4: Any other operators not in the above 3 categories http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 7c100af..59cb26d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -359,12 +359,12 @@ package object dsl { def generate( generator: Generator, - join: Boolean = false, + unrequiredChildIndex: Seq[Int] = Nil, outer: Boolean = false, alias: Option[String] = None, outputNames: Seq[String] = Nil): LogicalPlan = - Generate(generator, join = join, outer = outer, alias, - outputNames.map(UnresolvedAttribute(_)), logicalPlan) + Generate(generator, unrequiredChildIndex, outer, + alias, outputNames.map(UnresolvedAttribute(_)), logicalPlan) def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = InsertIntoTable( http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 6a4d1e9..eeb1b13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -456,12 +456,15 @@ object ColumnPruning extends Rule[LogicalPlan] { f.copy(child = prunedChild(child, 
f.references)) case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty => e.copy(child = prunedChild(child, e.references)) - case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty => - g.copy(child = prunedChild(g.child, g.references)) - // Turn off `join` for Generate if no column from it's child is used - case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) => - p.copy(child = g.copy(join = false)) + // prune unrequired references + case p @ Project(_, g: Generate) if p.references != g.outputSet => + val requiredAttrs = p.references -- g.producedAttributes ++ g.generator.references + val newChild = prunedChild(g.child, requiredAttrs) + val unrequired = g.generator.references -- p.references + val unrequiredIndices = newChild.output.zipWithIndex.filter(t => unrequired.contains(t._1)) + .map(_._2) + p.copy(child = g.copy(child = newChild, unrequiredChildIndex = unrequiredIndices)) // Eliminate unneeded attributes from right side of a Left Existence Join. 
case j @ Join(_, right, LeftExistence(_), _) => http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 7651d11..bdc357d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -623,7 +623,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val expressions = expressionList(ctx.expression) Generate( UnresolvedGenerator(visitFunctionName(ctx.qualifiedName), expressions), - join = true, + unrequiredChildIndex = Nil, outer = ctx.OUTER != null, Some(ctx.tblName.getText.toLowerCase), ctx.colName.asScala.map(_.getText).map(UnresolvedAttribute.apply), http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index cd47455..95e099c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -73,8 +73,13 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend * their output. 
* * @param generator the generator expression - * @param join when true, each output row is implicitly joined with the input tuple that produced - * it. + * @param unrequiredChildIndex this parameter starts as Nil and gets filled by the Optimizer. + * It's used as an optimization for omitting data generation that will + * be discarded next by a projection. + * A common use case is when we explode(array(..)) and are interested + * only in the exploded data and not in the original array. Before this + * optimization the array got duplicated for each of its elements, + * causing O(n^2) memory consumption. (see [SPARK-21657]) * @param outer when true, each input row will be output at least once, even if the output of the * given `generator` is empty. * @param qualifier Qualifier for the attributes of generator(UDTF) @@ -83,15 +88,17 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend */ case class Generate( generator: Generator, - join: Boolean, + unrequiredChildIndex: Seq[Int], outer: Boolean, qualifier: Option[String], generatorOutput: Seq[Attribute], child: LogicalPlan) extends UnaryNode { - /** The set of all attributes produced by this node. 
*/ - def generatedSet: AttributeSet = AttributeSet(generatorOutput) + lazy val requiredChildOutput: Seq[Attribute] = { + val unrequiredSet = unrequiredChildIndex.toSet + child.output.zipWithIndex.filterNot(t => unrequiredSet.contains(t._2)).map(_._1) + } override lazy val resolved: Boolean = { generator.resolved && @@ -114,9 +121,7 @@ case class Generate( nullableOutput } - def output: Seq[Attribute] = { - if (join) child.output ++ qualifiedGeneratorOutput else qualifiedGeneratorOutput - } + def output: Seq[Attribute] = requiredChildOutput ++ qualifiedGeneratorOutput } case class Filter(condition: Expression, child: LogicalPlan) http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 77e4eff..9f0f7e1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -38,54 +38,64 @@ class ColumnPruningSuite extends PlanTest { CollapseProject) :: Nil } - test("Column pruning for Generate when Generate.join = false") { - val input = LocalRelation('a.int, 'b.array(StringType)) + test("Column pruning for Generate when Generate.unrequiredChildIndex = child.output") { + val input = LocalRelation('a.int, 'b.int, 'c.array(StringType)) - val query = input.generate(Explode('b), join = false).analyze + val query = + input + .generate(Explode('c), outputNames = "explode" :: Nil) + .select('c, 'explode) + .analyze val optimized = Optimize.execute(query) - val correctAnswer = input.select('b).generate(Explode('b), join = false).analyze + val correctAnswer = + input + 
.select('c) + .generate(Explode('c), outputNames = "explode" :: Nil) + .analyze comparePlans(optimized, correctAnswer) } - test("Column pruning for Generate when Generate.join = true") { - val input = LocalRelation('a.int, 'b.int, 'c.array(StringType)) + test("Fill Generate.unrequiredChildIndex if possible") { + val input = LocalRelation('b.array(StringType)) val query = input - .generate(Explode('c), join = true, outputNames = "explode" :: Nil) - .select('a, 'explode) + .generate(Explode('b), outputNames = "explode" :: Nil) + .select(('explode + 1).as("result")) .analyze val optimized = Optimize.execute(query) val correctAnswer = input - .select('a, 'c) - .generate(Explode('c), join = true, outputNames = "explode" :: Nil) - .select('a, 'explode) + .generate(Explode('b), unrequiredChildIndex = input.output.zipWithIndex.map(_._2), + outputNames = "explode" :: Nil) + .select(('explode + 1).as("result")) .analyze comparePlans(optimized, correctAnswer) } - test("Turn Generate.join to false if possible") { - val input = LocalRelation('b.array(StringType)) + test("Another fill Generate.unrequiredChildIndex if possible") { + val input = LocalRelation('a.int, 'b.int, 'c1.string, 'c2.string) val query = input - .generate(Explode('b), join = true, outputNames = "explode" :: Nil) - .select(('explode + 1).as("result")) + .generate(Explode(CreateArray(Seq('c1, 'c2))), outputNames = "explode" :: Nil) + .select('a, 'c1, 'explode) .analyze val optimized = Optimize.execute(query) val correctAnswer = input - .generate(Explode('b), join = false, outputNames = "explode" :: Nil) - .select(('explode + 1).as("result")) + .select('a, 'c1, 'c2) + .generate(Explode(CreateArray(Seq('c1, 'c2))), + unrequiredChildIndex = Seq(2), + outputNames = "explode" :: Nil) .analyze comparePlans(optimized, correctAnswer) http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala 
---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 641824e..4a23179 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -624,14 +624,14 @@ class FilterPushdownSuite extends PlanTest { test("generate: predicate referenced no generated column") { val originalQuery = { testRelationWithArrayType - .generate(Explode('c_arr), true, false, Some("arr")) + .generate(Explode('c_arr), alias = Some("arr")) .where(('b >= 5) && ('a > 6)) } val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = { testRelationWithArrayType .where(('b >= 5) && ('a > 6)) - .generate(Explode('c_arr), true, false, Some("arr")).analyze + .generate(Explode('c_arr), alias = Some("arr")).analyze } comparePlans(optimized, correctAnswer) @@ -640,14 +640,14 @@ class FilterPushdownSuite extends PlanTest { test("generate: non-deterministic predicate referenced no generated column") { val originalQuery = { testRelationWithArrayType - .generate(Explode('c_arr), true, false, Some("arr")) + .generate(Explode('c_arr), alias = Some("arr")) .where(('b >= 5) && ('a + Rand(10).as("rnd") > 6) && ('col > 6)) } val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = { testRelationWithArrayType .where('b >= 5) - .generate(Explode('c_arr), true, false, Some("arr")) + .generate(Explode('c_arr), alias = Some("arr")) .where('a + Rand(10).as("rnd") > 6 && 'col > 6) .analyze } @@ -659,14 +659,14 @@ class FilterPushdownSuite extends PlanTest { val generator = Explode('c_arr) val originalQuery = { testRelationWithArrayType - .generate(generator, true, false, Some("arr")) + .generate(generator, alias = 
Some("arr")) .where(('b >= 5) && ('c > 6)) } val optimized = Optimize.execute(originalQuery.analyze) val referenceResult = { testRelationWithArrayType .where('b >= 5) - .generate(generator, true, false, Some("arr")) + .generate(generator, alias = Some("arr")) .where('c > 6).analyze } @@ -687,7 +687,7 @@ class FilterPushdownSuite extends PlanTest { test("generate: all conjuncts referenced generated column") { val originalQuery = { testRelationWithArrayType - .generate(Explode('c_arr), true, false, Some("arr")) + .generate(Explode('c_arr), alias = Some("arr")) .where(('col > 6) || ('b > 5)).analyze } val optimized = Optimize.execute(originalQuery) http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index d34a83c..812bfdd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -276,7 +276,7 @@ class PlanParserSuite extends AnalysisTest { assertEqual( "select * from t lateral view explode(x) expl as x", table("t") - .generate(explode, join = true, outer = false, Some("expl"), Seq("x")) + .generate(explode, alias = Some("expl"), outputNames = Seq("x")) .select(star())) // Multiple lateral views @@ -286,12 +286,12 @@ class PlanParserSuite extends AnalysisTest { |lateral view explode(x) expl |lateral view outer json_tuple(x, y) jtup q, z""".stripMargin, table("t") - .generate(explode, join = true, outer = false, Some("expl"), Seq.empty) - .generate(jsonTuple, join = true, outer = true, Some("jtup"), Seq("q", "z")) + .generate(explode, alias = Some("expl")) + .generate(jsonTuple, outer = 
true, alias = Some("jtup"), outputNames = Seq("q", "z")) .select(star())) // Multi-Insert lateral views. - val from = table("t1").generate(explode, join = true, outer = false, Some("expl"), Seq("x")) + val from = table("t1").generate(explode, alias = Some("expl"), outputNames = Seq("x")) assertEqual( """from t1 |lateral view explode(x) expl as x @@ -303,7 +303,7 @@ class PlanParserSuite extends AnalysisTest { |where s < 10 """.stripMargin, Union(from - .generate(jsonTuple, join = true, outer = false, Some("jtup"), Seq("q", "z")) + .generate(jsonTuple, alias = Some("jtup"), outputNames = Seq("q", "z")) .select(star()) .insertInto("t2"), from.where('s < 10).select(star()).insertInto("t3"))) @@ -312,10 +312,8 @@ class PlanParserSuite extends AnalysisTest { val expected = table("t") .generate( UnresolvedGenerator(FunctionIdentifier("posexplode"), Seq('x)), - join = true, - outer = false, - Some("posexpl"), - Seq("x", "y")) + alias = Some("posexpl"), + outputNames = Seq("x", "y")) .select(star()) assertEqual( "select * from t lateral view posexplode(x) posexpl as x, y", http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 209b800..77e5712 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2095,7 +2095,7 @@ class Dataset[T] private[sql]( val generator = UserDefinedGenerator(elementSchema, rowFunction, input.map(_.expr)) withPlan { - Generate(generator, join = true, outer = false, + Generate(generator, unrequiredChildIndex = Nil, outer = false, qualifier = None, generatorOutput = Nil, planWithBarrier) } } @@ -2136,7 +2136,7 @@ class Dataset[T] private[sql]( val generator = UserDefinedGenerator(elementSchema, 
rowFunction, apply(inputColumn).expr :: Nil) withPlan { - Generate(generator, join = true, outer = false, + Generate(generator, unrequiredChildIndex = Nil, outer = false, qualifier = None, generatorOutput = Nil, planWithBarrier) } } http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index e1562be..0c2c4a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -47,8 +47,7 @@ private[execution] sealed case class LazyIterator(func: () => TraversableOnce[In * terminate(). * * @param generator the generator expression - * @param join when true, each output row is implicitly joined with the input tuple that produced - * it. + * @param requiredChildOutput required attributes from child's output * @param outer when true, each input row will be output at least once, even if the output of the * given `generator` is empty. 
* @param generatorOutput the qualified output attributes of the generator of this node, which @@ -57,19 +56,13 @@ private[execution] sealed case class LazyIterator(func: () => TraversableOnce[In */ case class GenerateExec( generator: Generator, - join: Boolean, + requiredChildOutput: Seq[Attribute], outer: Boolean, generatorOutput: Seq[Attribute], child: SparkPlan) extends UnaryExecNode with CodegenSupport { - override def output: Seq[Attribute] = { - if (join) { - child.output ++ generatorOutput - } else { - generatorOutput - } - } + override def output: Seq[Attribute] = requiredChildOutput ++ generatorOutput override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -85,11 +78,19 @@ case class GenerateExec( val numOutputRows = longMetric("numOutputRows") child.execute().mapPartitionsWithIndexInternal { (index, iter) => val generatorNullRow = new GenericInternalRow(generator.elementSchema.length) - val rows = if (join) { + val rows = if (requiredChildOutput.nonEmpty) { + + val pruneChildForResult: InternalRow => InternalRow = + if (child.outputSet == AttributeSet(requiredChildOutput)) { + identity + } else { + UnsafeProjection.create(requiredChildOutput, child.output) + } + val joinedRow = new JoinedRow iter.flatMap { row => - // we should always set the left (child output) - joinedRow.withLeft(row) + // we should always set the left (required child output) + joinedRow.withLeft(pruneChildForResult(row)) val outputRows = boundGenerator.eval(row) if (outer && outputRows.isEmpty) { joinedRow.withRight(generatorNullRow) :: Nil @@ -136,7 +137,7 @@ case class GenerateExec( override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { // Add input rows to the values when we are joining - val values = if (join) { + val values = if (requiredChildOutput.nonEmpty) { input } else { Seq.empty 
http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 0ed7c2f..9102948 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -499,10 +499,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.GlobalLimitExec(limit, planLater(child)) :: Nil case logical.Union(unionChildren) => execution.UnionExec(unionChildren.map(planLater)) :: Nil - case g @ logical.Generate(generator, join, outer, _, _, child) => + case g @ logical.Generate(generator, _, outer, _, _, child) => execution.GenerateExec( - generator, join = join, outer = outer, g.qualifiedGeneratorOutput, - planLater(child)) :: Nil + generator, g.requiredChildOutput, outer, + g.qualifiedGeneratorOutput, planLater(child)) :: Nil case _: logical.OneRowRelation => execution.RDDScanExec(Nil, singleRowRdd, "OneRowRelation") :: Nil case r: logical.Range => http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala index 01773c2..f039aea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala @@ -202,6 +202,42 @@ class MiscBenchmark extends BenchmarkBase { generate inline array wholestage 
off 6901 / 6928 2.4 411.3 1.0X generate inline array wholestage on 1001 / 1010 16.8 59.7 6.9X */ + + val M = 60000 + runBenchmark("generate big struct array", M) { + import sparkSession.implicits._ + val df = sparkSession.sparkContext.parallelize(Seq(("1", + Array.fill(M)({ + val i = math.random + (i.toString, (i + 1).toString, (i + 2).toString, (i + 3).toString) + })))).toDF("col", "arr") + + df.selectExpr("*", "explode(arr) as arr_col") + .select("col", "arr_col.*").count + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 + Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz + + To test the impact of adding the Generate.unrequiredChildIndex optimization: + we can see an enormous improvement of about 250x in this case, and it grows as O(n^2). + + with Optimization ON: + + generate big struct array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + generate big struct array wholestage off 331 / 378 0.2 5524.9 1.0X + generate big struct array wholestage on 205 / 232 0.3 3413.1 1.6X + + with Optimization OFF: + + generate big struct array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + generate big struct array wholestage off 49697 / 51496 0.0 828277.7 1.0X + generate big struct array wholestage on 50558 / 51434 0.0 842641.6 1.0X + */ + } ignore("generate regular generator") { @@ -227,4 +263,5 @@ class MiscBenchmark extends BenchmarkBase { generate stack wholestage on 836 / 847 20.1 49.8 15.5X */ } + }
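The core of the new `Generate.requiredChildOutput` added in basicLogicalOperators.scala is a positional filter over the child's output. A minimal, Spark-free sketch of that computation (using strings as stand-ins for `Attribute`s; the object name is invented):

```scala
object RequiredChildOutputSketch {
  // Mirrors Generate.requiredChildOutput: drop the child output columns whose
  // position appears in unrequiredChildIndex, keep the rest in order.
  def requiredChildOutput(childOutput: Seq[String], unrequiredChildIndex: Seq[Int]): Seq[String] = {
    val unrequiredSet = unrequiredChildIndex.toSet
    childOutput.zipWithIndex.filterNot { case (_, i) => unrequiredSet.contains(i) }.map(_._1)
  }

  def main(args: Array[String]): Unit = {
    // As in the ColumnPruningSuite test: columns (a, c1, c2) with index 2
    // unrequired leave only a and c1 flowing into Generate's output.
    assert(requiredChildOutput(Seq("a", "c1", "c2"), Seq(2)) == Seq("a", "c1"))
    // unrequiredChildIndex = Nil preserves the old join = true output.
    assert(requiredChildOutput(Seq("a", "b"), Nil) == Seq("a", "b"))
  }
}
```

`Generate.output` then becomes `requiredChildOutput ++ qualifiedGeneratorOutput`, which degenerates to the old `join = true` and `join = false` cases when the index list is empty or covers all child columns, respectively.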