This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new 17c56fc03b8 [SPARK-38531][SQL] Fix the condition of "Prune unrequired child index" branch of ColumnPruning
17c56fc03b8 is described below

commit 17c56fc03b8e7269b293d6957c542eab9d723d52
Author: minyyy <[email protected]>
AuthorDate: Fri Apr 8 10:43:38 2022 +0800

    [SPARK-38531][SQL] Fix the condition of "Prune unrequired child index" branch of ColumnPruning
    
    ### What changes were proposed in this pull request?
    
    The "prune unrequired references" branch has the condition:
    
    `case p @ Project(_, g: Generate) if p.references != g.outputSet =>`
    
    This is wrong: with generators like Inline, the project will always enter this branch as long as it does not use all of the generator output.
    
    Example:
    
    input: <col1: array<struct<a: struct<a: int, b: int>, b: int>>>
    
    Project(a.a as x)
    \- Generate(Inline(col1), ..., a, b)
    
    p.references is [a]
    g.outputSet is [a, b]
    
    This bug means we never enter the GeneratorNestedColumnAliasing branch below, and thus miss some optimization opportunities. This PR changes the condition to check whether some child output is not used by the project and is either not used by the generator or not already recorded in unrequiredChildIndex.
    
    ### Why are the changes needed?
    The wrong condition prevents generators such as Inline and PosExplode from being optimized by the rules that come after this branch. Before this PR, the test query added in this PR is not optimized because the optimization rule cannot be applied to it. After this PR, the optimization rule is correctly applied to the query.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Unit tests.
    
    Closes #35864 from minyyy/gnca_wrong_cond.
    
    Authored-by: minyyy <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
    (cherry picked from commit 4b9343593eca780ca30ffda45244a71413577884)
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../catalyst/optimizer/NestedColumnAliasing.scala  | 19 +++++++++++++
 .../spark/sql/catalyst/optimizer/Optimizer.scala   | 15 +++++-----
 .../catalyst/optimizer/ColumnPruningSuite.scala    | 32 ++++++++++++++++++++++
 3 files changed, 58 insertions(+), 8 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
index 4c7130e51e0..9cf2925cdd2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
@@ -312,6 +312,25 @@ object NestedColumnAliasing {
   }
 }
 
+object GeneratorUnrequiredChildrenPruning { // Extractor over Project(_, Generate): yields the rewritten plan, or None when nothing would change.
+  def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match {
+    case p @ Project(_, g: Generate) =>
+      val requiredAttrs = p.references ++ g.generator.references // needed by the Project above or by the generator itself
+      val newChild = ColumnPruning.prunedChild(g.child, requiredAttrs) // projects away child columns referenced by neither
+      val unrequired = g.generator.references -- p.references // generator inputs that the Project never reads
+      val unrequiredIndices = newChild.output.zipWithIndex.filter(t => unrequired.contains(t._1))
+        .map(_._2) // indices refer to newChild.output (the pruned child), not the original g.child
+      if (!newChild.fastEquals(g.child) ||
+        unrequiredIndices.toSet != g.unrequiredChildIndex.toSet) { // rewrite only on a real change (index order ignored); otherwise None lets later cases such as GeneratorNestedColumnAliasing match
+        Some(p.copy(child = g.copy(child = newChild, unrequiredChildIndex = unrequiredIndices)))
+      } else {
+        None
+      }
+    case _ => None
+  }
+}
+
+
 /**
  * This prunes unnecessary nested columns from [[Generate]], or [[Project]] -> 
[[Generate]]
  */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index debd5a66adb..66c2ad84cce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -831,13 +831,12 @@ object ColumnPruning extends Rule[LogicalPlan] {
       e.copy(child = prunedChild(child, e.references))
 
     // prune unrequired references
-    case p @ Project(_, g: Generate) if p.references != g.outputSet =>
-      val requiredAttrs = p.references -- g.producedAttributes ++ 
g.generator.references
-      val newChild = prunedChild(g.child, requiredAttrs)
-      val unrequired = g.generator.references -- p.references
-      val unrequiredIndices = newChild.output.zipWithIndex.filter(t => 
unrequired.contains(t._1))
-        .map(_._2)
-      p.copy(child = g.copy(child = newChild, unrequiredChildIndex = 
unrequiredIndices))
+    // There are 2 types of pruning here:
+    // 1. For attributes in g.child.outputSet that is not used by the 
generator nor the project,
+    //    we directly remove it from the output list of g.child.
+    // 2. For attributes that is not used by the project but it is used by the 
generator, we put
+    //    it in g.unrequiredChildIndex to save memory usage.
+    case GeneratorUnrequiredChildrenPruning(rewrittenPlan) => rewrittenPlan
 
     // prune unrequired nested fields from `Generate`.
     case GeneratorNestedColumnAliasing(rewrittenPlan) => rewrittenPlan
@@ -897,7 +896,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
   })
 
   /** Applies a projection only when the child is producing unnecessary 
attributes */
-  private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) =
+  def prunedChild(c: LogicalPlan, allReferences: AttributeSet): LogicalPlan =
     if (!c.outputSet.subsetOf(allReferences)) {
       Project(c.output.filter(allReferences.contains), c)
     } else {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index 0655acbcb1b..0101c855152 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
+import 
org.apache.spark.sql.catalyst.optimizer.NestedColumnAliasingSuite.collectGeneratedAliases
 import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -459,4 +460,35 @@ class ColumnPruningSuite extends PlanTest {
     val correctAnswer1 = Project(Seq('a), input).analyze
     comparePlans(Optimize.execute(plan1.analyze), correctAnswer1)
   }
+
+  test("SPARK-38531: Nested field pruning for Project and PosExplode") { // Project over PosExplode must now reach nested-field pruning
+    val name = StructType.fromDDL("first string, middle string, last string") // struct reused by 'name, 'friends and 'relatives
+    val employer = StructType.fromDDL("id int, company struct<name:string, 
address:string>")
+    val contact = LocalRelation(
+      'id.int,
+      'name.struct(name),
+      'address.string,
+      'friends.array(name),
+      'relatives.map(StringType, name),
+      'employer.struct(employer))
+
+    val query = contact
+      .select('id, 'friends) // 'id flows into the Generate but is never read above it
+      .generate(PosExplode('friends))
+      .select('col.getField("middle")) // only one nested field of the exploded struct is read
+      .analyze
+    val optimized = Optimize.execute(query)
+
+    val aliases = collectGeneratedAliases(optimized) // alias names chosen by the optimizer, reused to build `expected`
+
+    val expected = contact
+      // GetStructField is pushed down, unused id column is pruned.
+      .select(
+        'friends.getField("middle").as(aliases(0)))
+      .generate(PosExplode($"${aliases(0)}"),
+        unrequiredChildIndex = Seq(0)) // unrequiredChildIndex is added.
+      .select('col.as("col.middle")) // after pushdown, 'col already holds the middle field
+      .analyze
+    comparePlans(optimized, expected)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to