This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4b9343593ec [SPARK-38531][SQL] Fix the condition of "Prune unrequired
child index" branch of ColumnPruning
4b9343593ec is described below
commit 4b9343593eca780ca30ffda45244a71413577884
Author: minyyy <[email protected]>
AuthorDate: Fri Apr 8 10:43:38 2022 +0800
[SPARK-38531][SQL] Fix the condition of "Prune unrequired child index"
branch of ColumnPruning
### What changes were proposed in this pull request?
The "prune unrequired references" branch has the condition:
`case p @ Project(_, g: Generate) if p.references != g.outputSet => `
This is wrong, as generators like Inline will always enter this branch as
long as the project does not use all of the generator's output.
Example:
input: <col1: array<struct<a: struct<a: int, b: int>, b: int>>>
Project(a.a as x)
\- Generate(Inline(col1), ..., a, b)
p.references is [a]
g.outputSet is [a, b]
This bug prevents us from ever entering the GeneratorNestedColumnAliasing
branch below, and thus we miss some optimization opportunities. This PR changes
the condition to check whether the child output is not used by the project and
is either not used by the generator or not already put into
unrequiredChildOutput.
### Why are the changes needed?
The wrong condition prevents some expressions like Inline, PosExplode from
being optimized by rules after it. Before the PR, the test query added in the
PR is not optimized since the optimization rule is not able to apply to it.
After the PR the optimization rule can be correctly applied to the query.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Unit tests.
Closes #35864 from minyyy/gnca_wrong_cond.
Authored-by: minyyy <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../catalyst/optimizer/NestedColumnAliasing.scala | 19 +++++++++++++
.../spark/sql/catalyst/optimizer/Optimizer.scala | 15 +++++-----
.../catalyst/optimizer/ColumnPruningSuite.scala | 32 ++++++++++++++++++++++
3 files changed, 58 insertions(+), 8 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
index 4c7130e51e0..9cf2925cdd2 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
@@ -312,6 +312,25 @@ object NestedColumnAliasing {
}
}
+object GeneratorUnrequiredChildrenPruning {
+ def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match {
+ case p @ Project(_, g: Generate) =>
+ val requiredAttrs = p.references ++ g.generator.references
+ val newChild = ColumnPruning.prunedChild(g.child, requiredAttrs)
+ val unrequired = g.generator.references -- p.references
+ val unrequiredIndices = newChild.output.zipWithIndex.filter(t =>
unrequired.contains(t._1))
+ .map(_._2)
+ if (!newChild.fastEquals(g.child) ||
+ unrequiredIndices.toSet != g.unrequiredChildIndex.toSet) {
+ Some(p.copy(child = g.copy(child = newChild, unrequiredChildIndex =
unrequiredIndices)))
+ } else {
+ None
+ }
+ case _ => None
+ }
+}
+
+
/**
* This prunes unnecessary nested columns from [[Generate]], or [[Project]] ->
[[Generate]]
*/
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index debd5a66adb..66c2ad84cce 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -831,13 +831,12 @@ object ColumnPruning extends Rule[LogicalPlan] {
e.copy(child = prunedChild(child, e.references))
// prune unrequired references
- case p @ Project(_, g: Generate) if p.references != g.outputSet =>
- val requiredAttrs = p.references -- g.producedAttributes ++
g.generator.references
- val newChild = prunedChild(g.child, requiredAttrs)
- val unrequired = g.generator.references -- p.references
- val unrequiredIndices = newChild.output.zipWithIndex.filter(t =>
unrequired.contains(t._1))
- .map(_._2)
- p.copy(child = g.copy(child = newChild, unrequiredChildIndex =
unrequiredIndices))
+ // There are 2 types of pruning here:
+ // 1. For attributes in g.child.outputSet that is not used by the
generator nor the project,
+ // we directly remove it from the output list of g.child.
+ // 2. For attributes that is not used by the project but it is used by the
generator, we put
+ // it in g.unrequiredChildIndex to save memory usage.
+ case GeneratorUnrequiredChildrenPruning(rewrittenPlan) => rewrittenPlan
// prune unrequired nested fields from `Generate`.
case GeneratorNestedColumnAliasing(rewrittenPlan) => rewrittenPlan
@@ -897,7 +896,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
})
/** Applies a projection only when the child is producing unnecessary
attributes */
- private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) =
+ def prunedChild(c: LogicalPlan, allReferences: AttributeSet): LogicalPlan =
if (!c.outputSet.subsetOf(allReferences)) {
Project(c.output.filter(allReferences.contains), c)
} else {
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index 050a2e27036..b1a642e75b8 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
+import
org.apache.spark.sql.catalyst.optimizer.NestedColumnAliasingSuite.collectGeneratedAliases
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -463,4 +464,35 @@ class ColumnPruningSuite extends PlanTest {
val correctAnswer1 = Project(Seq($"a"), input).analyze
comparePlans(Optimize.execute(plan1.analyze), correctAnswer1)
}
+
+ test("SPARK-38531: Nested field pruning for Project and PosExplode") {
+ val name = StructType.fromDDL("first string, middle string, last string")
+ val employer = StructType.fromDDL("id int, company struct<name:string,
address:string>")
+ val contact = LocalRelation(
+ 'id.int,
+ 'name.struct(name),
+ 'address.string,
+ 'friends.array(name),
+ 'relatives.map(StringType, name),
+ 'employer.struct(employer))
+
+ val query = contact
+ .select('id, 'friends)
+ .generate(PosExplode('friends))
+ .select('col.getField("middle"))
+ .analyze
+ val optimized = Optimize.execute(query)
+
+ val aliases = collectGeneratedAliases(optimized)
+
+ val expected = contact
+ // GetStructField is pushed down, unused id column is pruned.
+ .select(
+ 'friends.getField("middle").as(aliases(0)))
+ .generate(PosExplode($"${aliases(0)}"),
+ unrequiredChildIndex = Seq(0)) // unrequiredChildIndex is added.
+ .select('col.as("col.middle"))
+ .analyze
+ comparePlans(optimized, expected)
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]