Repository: spark
Updated Branches:
  refs/heads/master 03545ce6d -> d54d8b863


[SPARK-24696][SQL] ColumnPruning rule fails to remove extra Project

## What changes were proposed in this pull request?

The ColumnPruning rule adds an extra Project when an input node produces more 
fields than needed, but as a post-processing step it must remove the lower 
Project in a "Project - Filter - Project" pattern; otherwise it conflicts with 
PushPredicatesThroughProject and causes an infinite optimization loop. The 
current post-processing method is defined as:
```
  private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transform {
    case p1 @ Project(_, f @ Filter(_, p2 @ Project(_, child)))
      if p2.outputSet.subsetOf(child.outputSet) =>
      p1.copy(child = f.copy(child = child))
  }
```
This method works well when there is only one Filter, but not when there are 
two or more Filters. In the reported case, one filter is deterministic and the 
other is non-deterministic, so they remain separate Filter nodes and cannot be 
combined.
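For concreteness, here is a minimal sketch of a query shape that hits this case 
(it mirrors the regression test added to ColumnPruningSuite below; the 
SparkSession setup, local data, and column names are just placeholders):
```
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.rand

val spark = SparkSession.builder().master("local[*]").appName("spark-24696-sketch").getOrCreate()
import spark.implicits._

// rand(0L) > 0.5 is a non-deterministic predicate, so it stays in its own Filter
// node and is never merged with the deterministic 'key < 10 filter, leaving the
// Filter - Filter - Project shape that the loop illustrated below starts from.
val df = Seq((1, "a"), (20, "b")).toDF("key", "value")
  .select($"key")
  .where(rand(0L) > 0.5)
  .where($"key" < 10)

// Before the fix, optimizing a plan of this shape could keep looping between the
// two rules instead of converging.
df.explain(true)
```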

A simplified illustration of the optimization process that forms the infinite 
loop is shown below (F1 stands for the first filter, F2 for the second filter, 
P for a Project, S for the scan of the relation, and PredicatePushDown 
abbreviates PushPredicatesThroughProject):
```
                             F1 - F2 - P - S
PredicatePushDown      =>    F1 - P - F2 - S
ColumnPruning          =>    F1 - P - F2 - P - S
                       =>    F1 - P - F2 - S        (Project removed)
PredicatePushDown      =>    P - F1 - F2 - S
ColumnPruning          =>    P - F1 - P - F2 - S
                       =>    P - F1 - P - F2 - P - S
                       =>    P - F1 - F2 - P - S    (only one Project removed)
RemoveRedundantProject =>    F1 - F2 - P - S        (goes back to the loop start)
```
So the problem is that the ColumnPruning rule adds a Project under a Filter 
(and fails to remove it in the end), and that new Project triggers 
PushPredicatesThroughProject. Once the filters have been pushed through the 
Project, the ColumnPruning rule adds a new Project, and this goes on and on.
The fix: the Projects are added top-down, so when the rule later removes the 
extra Projects it should traverse the plan bottom-up, ensuring every extra 
Project can be matched.
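To make the traversal-order point concrete, here is a self-contained sketch 
using a simplified stand-in for the plan tree (these are not the real Catalyst 
classes; in Catalyst, `transform` is an alias for the top-down `transformDown`, 
which is why the lower Project can be missed):
```
// Toy plan tree standing in for Catalyst's LogicalPlan (simplified model).
sealed trait Plan
case class Project(child: Plan) extends Plan
case class Filter(child: Plan) extends Plan
case object Scan extends Plan

// The post-processing rule: drop the Project sandwiched between a Project and a Filter.
def removeProjectBeforeFilter(p: Plan): Plan = p match {
  case Project(Filter(Project(child))) => Project(Filter(child))
  case other => other
}

// Top-down traversal, analogous to Catalyst's transform/transformDown.
def transformDown(p: Plan)(rule: Plan => Plan): Plan = rule(p) match {
  case Project(c) => Project(transformDown(c)(rule))
  case Filter(c)  => Filter(transformDown(c)(rule))
  case Scan       => Scan
}

// Bottom-up traversal, analogous to Catalyst's transformUp.
def transformUp(p: Plan)(rule: Plan => Plan): Plan = rule(p match {
  case Project(c) => Project(transformUp(c)(rule))
  case Filter(c)  => Filter(transformUp(c)(rule))
  case Scan       => Scan
})

// The problematic shape from the loop above: P - F1 - P - F2 - P - S.
val plan = Project(Filter(Project(Filter(Project(Scan)))))

transformDown(plan)(removeProjectBeforeFilter)
// => Project(Filter(Filter(Project(Scan))))  -- the lower extra Project survives
transformUp(plan)(removeProjectBeforeFilter)
// => Project(Filter(Filter(Scan)))           -- both extra Projects are removed
```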

## How was this patch tested?

Added an optimization rule test in ColumnPruningSuite and an end-to-end test in 
SQLQuerySuite.

Author: maryannxue <maryann...@apache.org>

Closes #21674 from maryannxue/spark-24696.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/797971ed
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/797971ed
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/797971ed

Branch: refs/heads/master
Commit: 797971ed42cab41cbc3d039c0af4b26199bff783
Parents: 03545ce
Author: maryannxue <maryann...@apache.org>
Authored: Fri Jun 29 23:46:12 2018 -0700
Committer: Xiao Li <gatorsm...@gmail.com>
Committed: Fri Jun 29 23:46:12 2018 -0700

----------------------------------------------------------------------
 .../apache/spark/sql/catalyst/dsl/package.scala |  1 +
 .../sql/catalyst/optimizer/Optimizer.scala      |  5 +++--
 .../catalyst/optimizer/ColumnPruningSuite.scala |  9 ++++++++-
 .../org/apache/spark/sql/SQLQuerySuite.scala    | 21 ++++++++++++++++++++
 4 files changed, 33 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/797971ed/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index efb2eba..8cf69c6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -149,6 +149,7 @@ package object dsl {
       }
     }
 
+    def rand(e: Long): Expression = Rand(Literal.create(e, LongType))
     def sum(e: Expression): Expression = Sum(e).toAggregateExpression()
     def sumDistinct(e: Expression): Expression = Sum(e).toAggregateExpression(isDistinct = true)
     def count(e: Expression): Expression = Count(e).toAggregateExpression()

http://git-wip-us.apache.org/repos/asf/spark/blob/797971ed/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index aa992de..2cc27d8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -526,9 +526,10 @@ object ColumnPruning extends Rule[LogicalPlan] {
 
   /**
    * The Project before Filter is not necessary but conflict with PushPredicatesThroughProject,
-   * so remove it.
+   * so remove it. Since the Projects have been added top-down, we need to remove in bottom-up
+   * order, otherwise lower Projects can be missed.
    */
-  private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transform {
+  private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transformUp {
     case p1 @ Project(_, f @ Filter(_, p2 @ Project(_, child)))
       if p2.outputSet.subsetOf(child.outputSet) =>
       p1.copy(child = f.copy(child = child))

http://git-wip-us.apache.org/repos/asf/spark/blob/797971ed/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index 3f41f4b..8b05ba3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import scala.reflect.runtime.universe.TypeTag
 
-import org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -370,5 +369,13 @@ class ColumnPruningSuite extends PlanTest {
     comparePlans(optimized2, expected2.analyze)
   }
 
+  test("SPARK-24696 ColumnPruning rule fails to remove extra Project") {
+    val input = LocalRelation('key.int, 'value.string)
+    val query = input.select('key).where(rand(0L) > 0.5).where('key < 10).analyze
+    val optimized = Optimize.execute(query)
+    val expected = input.where(rand(0L) > 0.5).where('key < 10).select('key).analyze
+    comparePlans(optimized, expected)
+  }
+
   // todo: add more tests for column pruning
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/797971ed/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 640affc..dfb9c13 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2792,4 +2792,25 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("SPARK-24696 ColumnPruning rule fails to remove extra Project") {
+    withTable("fact_stats", "dim_stats") {
+      val factData = Seq((1, 1, 99, 1), (2, 2, 99, 2), (3, 1, 99, 3), (4, 2, 99, 4))
+      val storeData = Seq((1, "BW", "DE"), (2, "AZ", "US"))
+      spark.udf.register("filterND", udf((value: Int) => value > 2).asNondeterministic)
+      factData.toDF("date_id", "store_id", "product_id", "units_sold")
+        .write.mode("overwrite").partitionBy("store_id").format("parquet").saveAsTable("fact_stats")
+      storeData.toDF("store_id", "state_province", "country")
+        .write.mode("overwrite").format("parquet").saveAsTable("dim_stats")
+      val df = sql(
+        """
+          |SELECT f.date_id, f.product_id, f.store_id FROM
+          |(SELECT date_id, product_id, store_id
+          |   FROM fact_stats WHERE filterND(date_id)) AS f
+          |JOIN dim_stats s
+          |ON f.store_id = s.store_id WHERE s.country = 'DE'
+        """.stripMargin)
+      checkAnswer(df, Seq(Row(3, 99, 1)))
+    }
+  }
 }

