This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.1 by this push:
new c51f644 [SPARK-37392][SQL] Fix the performance bug when inferring
constraints for Generate
c51f644 is described below
commit c51f6449d38d30d0bff22df895dca515898a520b
Author: Wenchen Fan <[email protected]>
AuthorDate: Wed Dec 8 13:04:40 2021 +0800
[SPARK-37392][SQL] Fix the performance bug when inferring constraints for
Generate
This is a performance regression since Spark 3.1, caused by
https://issues.apache.org/jira/browse/SPARK-32295
If you run the query in the JIRA ticket
```
Seq(
(1, "x", "x", "x", "x", "x", "x", "x", "x", "x", "x", "x", "x", "x", "x",
"x", "x", "x", "x", "x", "x")
).toDF()
.checkpoint() // or save and reload to truncate lineage
.createOrReplaceTempView("sub")
session.sql("""
SELECT
*
FROM
(
SELECT
EXPLODE( ARRAY( * ) ) result
FROM
(
SELECT
_1 a, _2 b, _3 c, _4 d, _5 e, _6 f, _7 g, _8 h, _9 i, _10 j, _11 k,
_12 l, _13 m, _14 n, _15 o, _16 p, _17 q, _18 r, _19 s, _20 t, _21 u
FROM
sub
)
)
WHERE
result != ''
""").show()
```
You will hit OOM. The reason is that:
1. We infer additional predicates with `Generate`. In this case, it's
`size(array(cast(_1#21 as string), _2#22, _3#23, ...)) > 0`
2. Because of the cast, the `ConstantFolding` rule can't optimize this
`size(array(...))`.
3. We end up with a plan containing this part
```
+- Project [_1#21 AS a#106, _2#22 AS b#107, _3#23 AS c#108, _4#24 AS
d#109, _5#25 AS e#110, _6#26 AS f#111, _7#27 AS g#112, _8#28 AS h#113, _9#29 AS
i#114, _10#30 AS j#115, _11#31 AS k#116, _12#32 AS l#117, _13#33 AS m#118,
_14#34 AS n#119, _15#35 AS o#120, _16#36 AS p#121, _17#37 AS q#122, _18#38 AS
r#123, _19#39 AS s#124, _20#40 AS t#125, _21#41 AS u#126]
+- Filter (size(array(cast(_1#21 as string), _2#22, _3#23, _4#24,
_5#25, _6#26, _7#27, _8#28, _9#29, _10#30, _11#31, _12#32, _13#33, _14#34,
_15#35, _16#36, _17#37, _18#38, _19#39, _20#40, _21#41), true) > 0)
+- LogicalRDD [_1#21, _2#22, _3#23, _4#24, _5#25, _6#26, _7#27,
_8#28, _9#29, _10#30, _11#31, _12#32, _13#33, _14#34, _15#35, _16#36, _17#37,
_18#38, _19#39, _20#40, _21#41]
```
When calculating the constraints of the `Project`, we generate around 2^20
expressions, due to this code
```
var allConstraints = child.constraints
projectList.foreach {
case a @ Alias(l: Literal, _) =>
allConstraints += EqualNullSafe(a.toAttribute, l)
case a @ Alias(e, _) =>
// For every alias in `projectList`, replace the reference in
constraints by its attribute.
allConstraints ++= allConstraints.map(_ transform {
case expr: Expression if expr.semanticEquals(e) =>
a.toAttribute
})
allConstraints += EqualNullSafe(e, a.toAttribute)
case _ => // Don't change.
}
```
There are 3 issues here:
1. We may infer complicated predicates from `Generate`
2. `ConstantFolding` rule is too conservative. At least `Cast` has no side
effect with ANSI-off.
3. When calculating constraints, we should have an upper bound to avoid
generating too many expressions.
This fixes the first 2 issues, and leaves the third one for the future.
fix a performance issue
no
new tests, and run the query in JIRA ticket locally.
Closes #34823 from cloud-fan/perf.
Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit 1fac7a9d9992b7c120f325cdfa6a935b52c7f3bc)
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/sql/catalyst/optimizer/Optimizer.scala | 41 +++++----
.../spark/sql/catalyst/optimizer/expressions.scala | 1 +
.../optimizer/InferFiltersFromGenerateSuite.scala | 98 ++++++++++------------
3 files changed, 67 insertions(+), 73 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 99b5240..e39fa23 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -893,25 +893,30 @@ object TransposeWindow extends Rule[LogicalPlan] {
* by this [[Generate]] can be removed earlier - before joins and in data
sources.
*/
object InferFiltersFromGenerate extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
- // This rule does not infer filters from foldable expressions to avoid
constant filters
- // like 'size([1, 2, 3]) > 0'. These do not show up in child's constraints
and
- // then the idempotence will break.
- case generate @ Generate(e, _, _, _, _, _)
- if !e.deterministic || e.children.forall(_.foldable) ||
- e.children.exists(_.isInstanceOf[UserDefinedExpression]) => generate
-
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case generate @ Generate(g, _, false, _, _, _) if canInferFilters(g) =>
- // Exclude child's constraints to guarantee idempotency
- val inferredFilters = ExpressionSet(
- Seq(
- GreaterThan(Size(g.children.head), Literal(0)),
- IsNotNull(g.children.head)
- )
- ) -- generate.child.constraints
-
- if (inferredFilters.nonEmpty) {
- generate.copy(child = Filter(inferredFilters.reduce(And),
generate.child))
+ assert(g.children.length == 1)
+ val input = g.children.head
+ // Generating extra predicates here has overheads/risks:
+ // - We may evaluate expensive input expressions multiple times.
+ // - We may infer too many constraints later.
+ // - The input expression may fail to be evaluated under ANSI mode. If
we reorder the
+ // predicates and evaluate the input expression first, we may fail
the query unexpectedly.
+ // To be safe, here we only generate extra predicates if the input is an
attribute.
+ // Note that, foldable input is also excluded here, to avoid constant
filters like
+ // 'size([1, 2, 3]) > 0'. These do not show up in child's constraints
and then the
+ // idempotence will break.
+ if (input.isInstanceOf[Attribute]) {
+ // Exclude child's constraints to guarantee idempotency
+ val inferredFilters = ExpressionSet(
+ Seq(GreaterThan(Size(input), Literal(0)), IsNotNull(input))
+ ) -- generate.child.constraints
+
+ if (inferredFilters.nonEmpty) {
+ generate.copy(child = Filter(inferredFilters.reduce(And),
generate.child))
+ } else {
+ generate
+ }
} else {
generate
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index d989753..78098c4 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -46,6 +46,7 @@ object ConstantFolding extends Rule[LogicalPlan] {
private def hasNoSideEffect(e: Expression): Boolean = e match {
case _: Attribute => true
case _: Literal => true
+ case c: Cast if !conf.ansiEnabled => hasNoSideEffect(c.child)
case _: NoThrow if e.deterministic => e.children.forall(hasNoSideEffect)
case _ => false
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala
index 800d37e..61ab4f0 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala
@@ -18,10 +18,8 @@
package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -36,7 +34,7 @@ class InferFiltersFromGenerateSuite extends PlanTest {
val testRelation = LocalRelation('a.array(StructType(Seq(
StructField("x", IntegerType),
StructField("y", IntegerType)
- ))), 'c1.string, 'c2.string)
+ ))), 'c1.string, 'c2.string, 'c3.int)
Seq(Explode(_), PosExplode(_), Inline(_)).foreach { f =>
val generator = f('a)
@@ -74,63 +72,53 @@ class InferFiltersFromGenerateSuite extends PlanTest {
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, originalQuery)
}
- }
- // setup rules to test inferFilters with ConstantFolding to make sure
- // the Filter rule added in inferFilters is removed again when doing
- // explode with CreateArray/CreateMap
- object OptimizeInferAndConstantFold extends RuleExecutor[LogicalPlan] {
- val batches =
- Batch("AnalysisNodes", Once,
- EliminateSubqueryAliases) ::
- Batch("Infer Filters", Once, InferFiltersFromGenerate) ::
- Batch("ConstantFolding after", FixedPoint(4),
- ConstantFolding,
- NullPropagation,
- PruneFilters) :: Nil
+ val generatorWithFromJson = f(JsonToStructs(
+ ArrayType(new StructType().add("s", "string")),
+ Map.empty,
+ 'c1))
+ test("SPARK-37392: Don't infer filters from " + generatorWithFromJson) {
+ val originalQuery = testRelation.generate(generatorWithFromJson).analyze
+ val optimized = Optimize.execute(originalQuery)
+ comparePlans(optimized, originalQuery)
+ }
+
+ val returnSchema = ArrayType(StructType(Seq(
+ StructField("x", IntegerType),
+ StructField("y", StringType)
+ )))
+ val fakeUDF = ScalaUDF(
+ (i: Int) => Array(Row.fromSeq(Seq(1, "a")), Row.fromSeq(Seq(2, "b"))),
+ returnSchema, 'c3 :: Nil, Nil)
+ val generatorWithUDF = f(fakeUDF)
+ test("SPARK-36715: Don't infer filters from " + generatorWithUDF) {
+ val originalQuery = testRelation.generate(generatorWithUDF).analyze
+ val optimized = Optimize.execute(originalQuery)
+ comparePlans(optimized, originalQuery)
+ }
}
Seq(Explode(_), PosExplode(_)).foreach { f =>
- val createArrayExplode = f(CreateArray(Seq('c1)))
- test("SPARK-33544: Don't infer filters from CreateArray " +
createArrayExplode) {
- val originalQuery = testRelation.generate(createArrayExplode).analyze
- val optimized = OptimizeInferAndConstantFold.execute(originalQuery)
- comparePlans(optimized, originalQuery)
- }
- val createMapExplode = f(CreateMap(Seq('c1, 'c2)))
- test("SPARK-33544: Don't infer filters from CreateMap " +
createMapExplode) {
- val originalQuery = testRelation.generate(createMapExplode).analyze
- val optimized = OptimizeInferAndConstantFold.execute(originalQuery)
- comparePlans(optimized, originalQuery)
- }
- }
-
- Seq(Inline(_)).foreach { f =>
- val createArrayStructExplode = f(CreateArray(Seq(CreateStruct(Seq('c1)))))
- test("SPARK-33544: Don't infer filters from CreateArray " +
createArrayStructExplode) {
- val originalQuery =
testRelation.generate(createArrayStructExplode).analyze
- val optimized = OptimizeInferAndConstantFold.execute(originalQuery)
- comparePlans(optimized, originalQuery)
- }
- }
+ val createArrayExplode = f(CreateArray(Seq('c1)))
+ test("SPARK-33544: Don't infer filters from " + createArrayExplode) {
+ val originalQuery = testRelation.generate(createArrayExplode).analyze
+ val optimized = Optimize.execute(originalQuery)
+ comparePlans(optimized, originalQuery)
+ }
+ val createMapExplode = f(CreateMap(Seq('c1, 'c2)))
+ test("SPARK-33544: Don't infer filters from " + createMapExplode) {
+ val originalQuery = testRelation.generate(createMapExplode).analyze
+ val optimized = Optimize.execute(originalQuery)
+ comparePlans(optimized, originalQuery)
+ }
+ }
- test("SPARK-36715: Don't infer filters from udf") {
- Seq(Explode(_), PosExplode(_), Inline(_)).foreach { f =>
- val returnSchema = ArrayType(StructType(Seq(
- StructField("x", IntegerType),
- StructField("y", StringType)
- )))
- val fakeUDF = ScalaUDF(
- (i: Int) => Array(Row.fromSeq(Seq(1, "a")), Row.fromSeq(Seq(2, "b"))),
- returnSchema, Literal(8) :: Nil,
- Option(ExpressionEncoder[Int]().resolveAndBind()) :: Nil)
- val generator = f(fakeUDF)
- val originalQuery = OneRowRelation().generate(generator).analyze
- val optimized = OptimizeInferAndConstantFold.execute(originalQuery)
- val correctAnswer = OneRowRelation()
- .generate(generator)
- .analyze
- comparePlans(optimized, correctAnswer)
+ Seq(Inline(_)).foreach { f =>
+ val createArrayStructExplode = f(CreateArray(Seq(CreateStruct(Seq('c1)))))
+ test("SPARK-33544: Don't infer filters from " + createArrayStructExplode) {
+ val originalQuery =
testRelation.generate(createArrayStructExplode).analyze
+ val optimized = Optimize.execute(originalQuery)
+ comparePlans(optimized, originalQuery)
}
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]