This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new 62708db  [SPARK-32635][SQL][2.4] Fix foldable propagation
62708db is described below

commit 62708db4f90f652cd9bc73998ac5f1e949bd41ac
Author: Peter Toth <peter.t...@gmail.com>
AuthorDate: Fri Sep 18 10:28:30 2020 -0700

    [SPARK-32635][SQL][2.4] Fix foldable propagation
    
    ### What changes were proposed in this pull request?
    This PR rewrites the `FoldablePropagation` rule so that attribute references in a node are replaced only with foldables coming from that node's children.
    
    Before this PR, in the case of the following example (with the setting `spark.sql.optimizer.excludedRules=org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation`):
    ```scala
    val a = Seq("1").toDF("col1").withColumn("col2", lit("1"))
    val b = Seq("2").toDF("col1").withColumn("col2", lit("2"))
    val aub = a.union(b)
    val c = aub.filter($"col1" === "2").cache()
    val d = Seq("2").toDF( "col4")
    val r = d.join(aub, $"col2" === $"col4").select("col4")
    val l = c.select("col2")
    val df = l.join(r, $"col2" === $"col4", "LeftOuter")
    df.show()
    ```
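    foldable propagation happens incorrectly (plan before the rule on the left, after the rule on the right; the row marked `!` shows `col2#6` being wrongly replaced by the literal `1`):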
    ```
     Join LeftOuter, (col2#6 = col4#34)                                                              Join LeftOuter, (col2#6 = col4#34)
    !:- Project [col2#6]                                                                             :- Project [1 AS col2#6]
     :  +- InMemoryRelation [col1#4, col2#6], StorageLevel(disk, memory, deserialized, 1 replicas)   :  +- InMemoryRelation [col1#4, col2#6], StorageLevel(disk, memory, deserialized, 1 replicas)
     :        +- Union                                                                               :        +- Union
     :           :- *(1) Project [value#1 AS col1#4, 1 AS col2#6]                                    :           :- *(1) Project [value#1 AS col1#4, 1 AS col2#6]
     :           :  +- *(1) Filter (isnotnull(value#1) AND (value#1 = 2))                            :           :  +- *(1) Filter (isnotnull(value#1) AND (value#1 = 2))
     :           :     +- *(1) LocalTableScan [value#1]                                              :           :     +- *(1) LocalTableScan [value#1]
     :           +- *(2) Project [value#10 AS col1#13, 2 AS col2#15]                                 :           +- *(2) Project [value#10 AS col1#13, 2 AS col2#15]
     :              +- *(2) Filter (isnotnull(value#10) AND (value#10 = 2))                          :              +- *(2) Filter (isnotnull(value#10) AND (value#10 = 2))
     :                 +- *(2) LocalTableScan [value#10]                                             :                 +- *(2) LocalTableScan [value#10]
     +- Project [col4#34]                                                                            +- Project [col4#34]
        +- Join Inner, (col2#6 = col4#34)                                                               +- Join Inner, (col2#6 = col4#34)
           :- Project [value#31 AS col4#34]                                                                :- Project [value#31 AS col4#34]
           :  +- LocalRelation [value#31]                                                                  :  +- LocalRelation [value#31]
           +- Project [col2#6]                                                                             +- Project [col2#6]
              +- Union false, false                                                                           +- Union false, false
                 :- Project [1 AS col2#6]                                                                        :- Project [1 AS col2#6]
                 :  +- LocalRelation [value#1]                                                                   :  +- LocalRelation [value#1]
                 +- Project [2 AS col2#15]                                                                       +- Project [2 AS col2#15]
                    +- LocalRelation [value#10]                                                                     +- LocalRelation [value#10]
    
    ```
    and so the result is wrong:
    ```
    +----+----+
    |col2|col4|
    +----+----+
    |   1|null|
    +----+----+
    ```
    
    After this PR, foldable propagation no longer happens incorrectly and the result is correct:
    ```
    +----+----+
    |col2|col4|
    +----+----+
    |   2|   2|
    +----+----+
    ```
    
    ### Why are the changes needed?
    To fix a correctness issue.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, fixes a correctness issue.
    
    ### How was this patch tested?
    Existing and new UTs.
    
    Closes #29805 from peter-toth/SPARK-32635-fix-foldable-propagation-2.4.
    
    Authored-by: Peter Toth <peter.t...@gmail.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../sql/catalyst/expressions/AttributeMap.scala    |   2 +
 .../spark/sql/catalyst/optimizer/expressions.scala | 121 ++++++++++++---------
 .../optimizer/FoldablePropagationSuite.scala       |  12 ++
 .../org/apache/spark/sql/DataFrameSuite.scala      |  12 ++
 4 files changed, 98 insertions(+), 49 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index 9f4a0f2..1e8f8ca 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -26,6 +26,8 @@ object AttributeMap {
   def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = {
     new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
   }
+
+  def empty[A]: AttributeMap[A] = new AttributeMap(Map.empty)
 }
 
 class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)])
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index be0e702..6c6f6313 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -595,59 +595,82 @@ object NullPropagation extends Rule[LogicalPlan] {
  */
 object FoldablePropagation extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = {
-    var foldableMap = AttributeMap(plan.flatMap {
-      case Project(projectList, _) => projectList.collect {
-        case a: Alias if a.child.foldable => (a.toAttribute, a)
-      }
-      case _ => Nil
-    })
-    val replaceFoldable: PartialFunction[Expression, Expression] = {
-      case a: AttributeReference if foldableMap.contains(a) => foldableMap(a)
+    CleanupAliases(propagateFoldables(plan)._1)
+  }
+
+  private def propagateFoldables(plan: LogicalPlan): (LogicalPlan, 
AttributeMap[Alias]) = {
+    plan match {
+      case p: Project =>
+        val (newChild, foldableMap) = propagateFoldables(p.child)
+        val newProject =
+          
replaceFoldable(p.withNewChildren(Seq(newChild)).asInstanceOf[Project], 
foldableMap)
+        val newFoldableMap = AttributeMap(newProject.projectList.collect {
+          case a: Alias if a.child.foldable => (a.toAttribute, a)
+        })
+        (newProject, newFoldableMap)
+
+      // We can not replace the attributes in `Expand.output`. If there are 
other non-leaf
+      // operators that have the `output` field, we should put them here too.
+      case e: Expand =>
+        val (newChild, foldableMap) = propagateFoldables(e.child)
+        val expandWithNewChildren = 
e.withNewChildren(Seq(newChild)).asInstanceOf[Expand]
+        val newExpand = if (foldableMap.isEmpty) {
+          expandWithNewChildren
+        } else {
+          val newProjections = 
expandWithNewChildren.projections.map(_.map(_.transform {
+            case a: AttributeReference if foldableMap.contains(a) => 
foldableMap(a)
+          }))
+          if (newProjections == expandWithNewChildren.projections) {
+            expandWithNewChildren
+          } else {
+            expandWithNewChildren.copy(projections = newProjections)
+          }
+        }
+        (newExpand, foldableMap)
+
+      case u: UnaryNode if canPropagateFoldables(u) =>
+        val (newChild, foldableMap) = propagateFoldables(u.child)
+        val newU = replaceFoldable(u.withNewChildren(Seq(newChild)), 
foldableMap)
+        (newU, foldableMap)
+
+      // Join derives the output attributes from its child while they are 
actually not the
+      // same attributes. For example, the output of outer join is not always 
picked from its
+      // children, but can also be null. We should exclude these miss-derived 
attributes when
+      // propagating the foldable expressions.
+      // TODO(cloud-fan): It seems more reasonable to use new attributes as 
the output attributes
+      // of outer join.
+      case j: Join =>
+        val (newChildren, foldableMaps) = 
j.children.map(propagateFoldables).unzip
+        val foldableMap = AttributeMap(
+          foldableMaps.foldLeft(Iterable.empty[(Attribute, Alias)])(_ ++ 
_.baseMap.values).toSeq)
+        val newJoin =
+          replaceFoldable(j.withNewChildren(newChildren).asInstanceOf[Join], 
foldableMap)
+        val missDerivedAttrsSet: AttributeSet = AttributeSet(newJoin.joinType 
match {
+          case _: InnerLike | LeftExistence(_) => Nil
+          case LeftOuter => newJoin.right.output
+          case RightOuter => newJoin.left.output
+          case FullOuter => newJoin.left.output ++ newJoin.right.output
+        })
+        val newFoldableMap = AttributeMap(foldableMap.baseMap.values.filterNot 
{
+          case (attr, _) => missDerivedAttrsSet.contains(attr)
+        }.toSeq)
+        (newJoin, newFoldableMap)
+
+      // For other plans, they are not safe to apply foldable propagation, and 
they should not
+      // propagate foldable expressions from children.
+      case o =>
+        val newOther = o.mapChildren(propagateFoldables(_)._1)
+        (newOther, AttributeMap.empty)
     }
+  }
 
+  private def replaceFoldable(plan: LogicalPlan, foldableMap: 
AttributeMap[Alias]): plan.type = {
     if (foldableMap.isEmpty) {
       plan
     } else {
-      CleanupAliases(plan.transformUp {
-        // We can only propagate foldables for a subset of unary nodes.
-        case u: UnaryNode if foldableMap.nonEmpty && canPropagateFoldables(u) 
=>
-          u.transformExpressions(replaceFoldable)
-
-        // Join derives the output attributes from its child while they are 
actually not the
-        // same attributes. For example, the output of outer join is not 
always picked from its
-        // children, but can also be null. We should exclude these 
miss-derived attributes when
-        // propagating the foldable expressions.
-        // TODO(cloud-fan): It seems more reasonable to use new attributes as 
the output attributes
-        // of outer join.
-        case j @ Join(left, right, joinType, _) if foldableMap.nonEmpty =>
-          val newJoin = j.transformExpressions(replaceFoldable)
-          val missDerivedAttrsSet: AttributeSet = AttributeSet(joinType match {
-            case _: InnerLike | LeftExistence(_) => Nil
-            case LeftOuter => right.output
-            case RightOuter => left.output
-            case FullOuter => left.output ++ right.output
-          })
-          foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot {
-            case (attr, _) => missDerivedAttrsSet.contains(attr)
-          }.toSeq)
-          newJoin
-
-        // We can not replace the attributes in `Expand.output`. If there are 
other non-leaf
-        // operators that have the `output` field, we should put them here too.
-        case expand: Expand if foldableMap.nonEmpty =>
-          expand.copy(projections = expand.projections.map { projection =>
-            projection.map(_.transform(replaceFoldable))
-          })
-
-        // For other plans, they are not safe to apply foldable propagation, 
and they should not
-        // propagate foldable expressions from children.
-        case other if foldableMap.nonEmpty =>
-          val childrenOutputSet = 
AttributeSet(other.children.flatMap(_.output))
-          foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot {
-            case (attr, _) => childrenOutputSet.contains(attr)
-          }.toSeq)
-          other
-      })
+      plan transformExpressions {
+        case a: AttributeReference if foldableMap.contains(a) => foldableMap(a)
+      }
     }
   }
 
@@ -655,7 +678,7 @@ object FoldablePropagation extends Rule[LogicalPlan] {
    * Whitelist of all [[UnaryNode]]s for which allow foldable propagation.
    */
   private def canPropagateFoldables(u: UnaryNode): Boolean = u match {
-    case _: Project => true
+    // Handling `Project` is moved to `propagateFoldables`.
     case _: Filter => true
     case _: SubqueryAlias => true
     case _: Aggregate => true
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala
index c288446..2c45199 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala
@@ -180,4 +180,16 @@ class FoldablePropagationSuite extends PlanTest {
       .select((Literal(1) + 3).as('res)).analyze
     comparePlans(optimized, correctAnswer)
   }
+
+  test("SPARK-32635: Replace references with foldables coming only from the 
node's children") {
+    val leftExpression = 'a.int
+    val left = LocalRelation(leftExpression).select('a)
+    val rightExpression = Alias(Literal(2), "a")(leftExpression.exprId)
+    val right = LocalRelation('b.int).select('b, rightExpression).select('b)
+    val join = left.join(right, joinType = LeftOuter, condition = Some('b === 
'a))
+
+    val query = join.analyze
+    val optimized = Optimize.execute(query)
+    comparePlans(optimized, query)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index e7d55ee..037cf23 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2720,6 +2720,18 @@ class DataFrameSuite extends QueryTest with 
SharedSQLContext {
     val nestedDecArray = Array(decSpark)
     checkAnswer(Seq(nestedDecArray).toDF(), Row(Array(wrapRefArray(decJava))))
   }
+
+  test("SPARK-32635: Replace references with foldables coming only from the 
node's children") {
+    val a = Seq("1").toDF("col1").withColumn("col2", lit("1"))
+    val b = Seq("2").toDF("col1").withColumn("col2", lit("2"))
+    val aub = a.union(b)
+    val c = aub.filter($"col1" === "2").cache()
+    val d = Seq("2").toDF("col4")
+    val r = d.join(aub, $"col2" === $"col4").select("col4")
+    val l = c.select("col2")
+    val df = l.join(r, $"col2" === $"col4", "LeftOuter")
+    checkAnswer(df, Row("2", "2"))
+  }
 }
 
 case class GroupByKey(a: Int, b: Int)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to