This is an automated email from the ASF dual-hosted git repository.

yamamuro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4ced588  [SPARK-32635][SQL] Fix foldable propagation
4ced588 is described below

commit 4ced58862c707aa916f7a55d15c3887c94c9b210
Author: Peter Toth <peter.t...@gmail.com>
AuthorDate: Fri Sep 18 08:17:23 2020 +0900

    [SPARK-32635][SQL] Fix foldable propagation
    
    ### What changes were proposed in this pull request?
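    This PR rewrites the `FoldablePropagation` rule so that it replaces attribute
    references in a node only with foldables coming from that node's children.
    Previously the rule collected foldable aliases from every `Project` in the whole
    plan, so a reference could be replaced with a foldable defined elsewhere in the
    plan (for example, in another branch of a join) that happened to share the
    attribute's expression ID.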
    
    Before this PR, in the case of this example (with the setting
    `spark.sql.optimizer.excludedRules=org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation`):
    ```scala
    val a = Seq("1").toDF("col1").withColumn("col2", lit("1"))
    val b = Seq("2").toDF("col1").withColumn("col2", lit("2"))
    val aub = a.union(b)
    val c = aub.filter($"col1" === "2").cache()
    val d = Seq("2").toDF("col4")
    val r = d.join(aub, $"col2" === $"col4").select("col4")
    val l = c.select("col2")
    val df = l.join(r, $"col2" === $"col4", "LeftOuter")
    df.show()
    ```
    foldable propagation is applied incorrectly (left: the plan before the rule;
    right: the plan after it; `!` marks the line that changed):
    ```
     Join LeftOuter, (col2#6 = col4#34)                                                              Join LeftOuter, (col2#6 = col4#34)
    !:- Project [col2#6]                                                                             :- Project [1 AS col2#6]
     :  +- InMemoryRelation [col1#4, col2#6], StorageLevel(disk, memory, deserialized, 1 replicas)   :  +- InMemoryRelation [col1#4, col2#6], StorageLevel(disk, memory, deserialized, 1 replicas)
     :        +- Union                                                                               :        +- Union
     :           :- *(1) Project [value#1 AS col1#4, 1 AS col2#6]                                    :           :- *(1) Project [value#1 AS col1#4, 1 AS col2#6]
     :           :  +- *(1) Filter (isnotnull(value#1) AND (value#1 = 2))                            :           :  +- *(1) Filter (isnotnull(value#1) AND (value#1 = 2))
     :           :     +- *(1) LocalTableScan [value#1]                                              :           :     +- *(1) LocalTableScan [value#1]
     :           +- *(2) Project [value#10 AS col1#13, 2 AS col2#15]                                 :           +- *(2) Project [value#10 AS col1#13, 2 AS col2#15]
     :              +- *(2) Filter (isnotnull(value#10) AND (value#10 = 2))                          :              +- *(2) Filter (isnotnull(value#10) AND (value#10 = 2))
     :                 +- *(2) LocalTableScan [value#10]                                             :                 +- *(2) LocalTableScan [value#10]
     +- Project [col4#34]                                                                            +- Project [col4#34]
        +- Join Inner, (col2#6 = col4#34)                                                               +- Join Inner, (col2#6 = col4#34)
           :- Project [value#31 AS col4#34]                                                                :- Project [value#31 AS col4#34]
           :  +- LocalRelation [value#31]                                                                  :  +- LocalRelation [value#31]
           +- Project [col2#6]                                                                             +- Project [col2#6]
              +- Union false, false                                                                           +- Union false, false
                 :- Project [1 AS col2#6]                                                                        :- Project [1 AS col2#6]
                 :  +- LocalRelation [value#1]                                                                   :  +- LocalRelation [value#1]
                 +- Project [2 AS col2#15]                                                                       +- Project [2 AS col2#15]
                    +- LocalRelation [value#10]                                                                     +- LocalRelation [value#10]
    
    ```
    and so the result is wrong:
    ```
    +----+----+
    |col2|col4|
    +----+----+
    |   1|null|
    +----+----+
    ```
    
    After this PR, foldable propagation is no longer applied incorrectly, and the
    result is correct:
    ```
    +----+----+
    |col2|col4|
    +----+----+
    |   2|   2|
    +----+----+
    ```
    
    ### Why are the changes needed?
    To fix a correctness issue.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, fixes a correctness issue.
    
    ### How was this patch tested?
    Existing and new UTs.
    
    Closes #29771 from peter-toth/SPARK-32635-fix-foldable-propagation.
    
    Authored-by: Peter Toth <peter.t...@gmail.com>
    Signed-off-by: Takeshi Yamamuro <yamam...@apache.org>
---
 .../sql/catalyst/expressions/AttributeMap.scala    |   2 +
 .../sql/catalyst/expressions/AttributeMap.scala    |   2 +
 .../spark/sql/catalyst/optimizer/expressions.scala | 121 ++++++++++++---------
 .../org/apache/spark/sql/DataFrameSuite.scala      |  12 ++
 4 files changed, 88 insertions(+), 49 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
 
b/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index 75a8bec..42b92d4 100644
--- 
a/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++ 
b/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -26,6 +26,8 @@ object AttributeMap {
   def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = {
     new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
   }
+
+  def empty[A]: AttributeMap[A] = new AttributeMap(Map.empty)
 }
 
 class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)])
diff --git 
a/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
 
b/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index 4caa3d0..e6b53e3 100644
--- 
a/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++ 
b/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -26,6 +26,8 @@ object AttributeMap {
   def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = {
     new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
   }
+
+  def empty[A]: AttributeMap[A] = new AttributeMap(Map.empty)
 }
 
 class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)])
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index b2fc393..c4e4b25 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -624,59 +624,82 @@ object NullPropagation extends Rule[LogicalPlan] {
  */
 object FoldablePropagation extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = {
-    var foldableMap = AttributeMap(plan.flatMap {
-      case Project(projectList, _) => projectList.collect {
-        case a: Alias if a.child.foldable => (a.toAttribute, a)
-      }
-      case _ => Nil
-    })
-    val replaceFoldable: PartialFunction[Expression, Expression] = {
-      case a: AttributeReference if foldableMap.contains(a) => foldableMap(a)
+    CleanupAliases(propagateFoldables(plan)._1)
+  }
+
+  private def propagateFoldables(plan: LogicalPlan): (LogicalPlan, 
AttributeMap[Alias]) = {
+    plan match {
+      case p: Project =>
+        val (newChild, foldableMap) = propagateFoldables(p.child)
+        val newProject =
+          
replaceFoldable(p.withNewChildren(Seq(newChild)).asInstanceOf[Project], 
foldableMap)
+        val newFoldableMap = AttributeMap(newProject.projectList.collect {
+          case a: Alias if a.child.foldable => (a.toAttribute, a)
+        })
+        (newProject, newFoldableMap)
+
+      // We can not replace the attributes in `Expand.output`. If there are 
other non-leaf
+      // operators that have the `output` field, we should put them here too.
+      case e: Expand =>
+        val (newChild, foldableMap) = propagateFoldables(e.child)
+        val expandWithNewChildren = 
e.withNewChildren(Seq(newChild)).asInstanceOf[Expand]
+        val newExpand = if (foldableMap.isEmpty) {
+          expandWithNewChildren
+        } else {
+          val newProjections = 
expandWithNewChildren.projections.map(_.map(_.transform {
+            case a: AttributeReference if foldableMap.contains(a) => 
foldableMap(a)
+          }))
+          if (newProjections == expandWithNewChildren.projections) {
+            expandWithNewChildren
+          } else {
+            expandWithNewChildren.copy(projections = newProjections)
+          }
+        }
+        (newExpand, foldableMap)
+
+      case u: UnaryNode if canPropagateFoldables(u) =>
+        val (newChild, foldableMap) = propagateFoldables(u.child)
+        val newU = replaceFoldable(u.withNewChildren(Seq(newChild)), 
foldableMap)
+        (newU, foldableMap)
+
+      // Join derives the output attributes from its child while they are 
actually not the
+      // same attributes. For example, the output of outer join is not always 
picked from its
+      // children, but can also be null. We should exclude these miss-derived 
attributes when
+      // propagating the foldable expressions.
+      // TODO(cloud-fan): It seems more reasonable to use new attributes as 
the output attributes
+      // of outer join.
+      case j: Join =>
+        val (newChildren, foldableMaps) = 
j.children.map(propagateFoldables).unzip
+        val foldableMap = AttributeMap(
+          foldableMaps.foldLeft(Iterable.empty[(Attribute, Alias)])(_ ++ 
_.baseMap.values).toSeq)
+        val newJoin =
+          replaceFoldable(j.withNewChildren(newChildren).asInstanceOf[Join], 
foldableMap)
+        val missDerivedAttrsSet: AttributeSet = AttributeSet(newJoin.joinType 
match {
+          case _: InnerLike | LeftExistence(_) => Nil
+          case LeftOuter => newJoin.right.output
+          case RightOuter => newJoin.left.output
+          case FullOuter => newJoin.left.output ++ newJoin.right.output
+        })
+        val newFoldableMap = AttributeMap(foldableMap.baseMap.values.filterNot 
{
+          case (attr, _) => missDerivedAttrsSet.contains(attr)
+        }.toSeq)
+        (newJoin, newFoldableMap)
+
+      // For other plans, they are not safe to apply foldable propagation, and 
they should not
+      // propagate foldable expressions from children.
+      case o =>
+        val newOther = o.mapChildren(propagateFoldables(_)._1)
+        (newOther, AttributeMap.empty)
     }
+  }
 
+  private def replaceFoldable(plan: LogicalPlan, foldableMap: 
AttributeMap[Alias]): plan.type = {
     if (foldableMap.isEmpty) {
       plan
     } else {
-      CleanupAliases(plan.transformUp {
-        // We can only propagate foldables for a subset of unary nodes.
-        case u: UnaryNode if foldableMap.nonEmpty && canPropagateFoldables(u) 
=>
-          u.transformExpressions(replaceFoldable)
-
-        // Join derives the output attributes from its child while they are 
actually not the
-        // same attributes. For example, the output of outer join is not 
always picked from its
-        // children, but can also be null. We should exclude these 
miss-derived attributes when
-        // propagating the foldable expressions.
-        // TODO(cloud-fan): It seems more reasonable to use new attributes as 
the output attributes
-        // of outer join.
-        case j @ Join(left, right, joinType, _, _) if foldableMap.nonEmpty =>
-          val newJoin = j.transformExpressions(replaceFoldable)
-          val missDerivedAttrsSet: AttributeSet = AttributeSet(joinType match {
-            case _: InnerLike | LeftExistence(_) => Nil
-            case LeftOuter => right.output
-            case RightOuter => left.output
-            case FullOuter => left.output ++ right.output
-          })
-          foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot {
-            case (attr, _) => missDerivedAttrsSet.contains(attr)
-          }.toSeq)
-          newJoin
-
-        // We can not replace the attributes in `Expand.output`. If there are 
other non-leaf
-        // operators that have the `output` field, we should put them here too.
-        case expand: Expand if foldableMap.nonEmpty =>
-          expand.copy(projections = expand.projections.map { projection =>
-            projection.map(_.transform(replaceFoldable))
-          })
-
-        // For other plans, they are not safe to apply foldable propagation, 
and they should not
-        // propagate foldable expressions from children.
-        case other if foldableMap.nonEmpty =>
-          val childrenOutputSet = 
AttributeSet(other.children.flatMap(_.output))
-          foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot {
-            case (attr, _) => childrenOutputSet.contains(attr)
-          }.toSeq)
-          other
-      })
+      plan transformExpressions {
+        case a: AttributeReference if foldableMap.contains(a) => foldableMap(a)
+      }
     }
   }
 
@@ -684,7 +707,7 @@ object FoldablePropagation extends Rule[LogicalPlan] {
    * List of all [[UnaryNode]]s which allow foldable propagation.
    */
   private def canPropagateFoldables(u: UnaryNode): Boolean = u match {
-    case _: Project => true
+    // Handling `Project` is moved to `propagateFoldables`.
     case _: Filter => true
     case _: SubqueryAlias => true
     case _: Aggregate => true
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index d95f09a..321f496 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2555,6 +2555,18 @@ class DataFrameSuite extends QueryTest
     val df = Seq(0.0 -> -0.0).toDF("pos", "neg")
     checkAnswer(df.select($"pos" > $"neg"), Row(false))
   }
+
+  test("SPARK-32635: Replace references with foldables coming only from the 
node's children") {
+    val a = Seq("1").toDF("col1").withColumn("col2", lit("1"))
+    val b = Seq("2").toDF("col1").withColumn("col2", lit("2"))
+    val aub = a.union(b)
+    val c = aub.filter($"col1" === "2").cache()
+    val d = Seq("2").toDF("col4")
+    val r = d.join(aub, $"col2" === $"col4").select("col4")
+    val l = c.select("col2")
+    val df = l.join(r, $"col2" === $"col4", "LeftOuter")
+    checkAnswer(df, Row("2", "2"))
+  }
 }
 
 case class GroupByKey(a: Int, b: Int)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to