Repository: spark
Updated Branches:
  refs/heads/branch-1.5 c3112a92f -> 54334d378


[SPARK-10740] [SQL] handle nondeterministic expressions correctly for set operations

https://issues.apache.org/jira/browse/SPARK-10740

Author: Wenchen Fan <cloud0...@163.com>

Closes #8858 from cloud-fan/non-deter.

(cherry picked from commit 5017c685f484ec256101d1d33bad11d9e0c0f641)
Signed-off-by: Yin Huai <yh...@databricks.com>
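
For readers skimming the diff: the bug was that SetOperationPushDown pushed any
filter or projection through Union/Intersect/Except. That is only sound for
deterministic expressions; a nondeterministic one such as rand() produces a
different stream of values depending on which rows (and which partition) it is
evaluated against, so moving it below the set operation changes the answer. A
minimal sketch of the failing shape, adapted from the regression test in this
commit (assumes a Spark 1.5-era SQLContext with implicits imported):

    import org.apache.spark.sql.functions.rand

    val df1 = (1 to 20).map(Tuple1.apply).toDF("i")
    val df2 = (1 to 10).map(Tuple1.apply).toDF("i")
    val union = df1.unionAll(df2)

    // rand(7) must be drawn once per row of the union's output; if the
    // filter were pushed below the union, each branch would evaluate the
    // generator against a different row stream and the result would change.
    union.filter('i < rand(7) * 10)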


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/54334d37
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/54334d37
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/54334d37

Branch: refs/heads/branch-1.5
Commit: 54334d3784bf2401f30deb8d1e9ce91da8aa468b
Parents: c3112a9
Author: Wenchen Fan <cloud0...@163.com>
Authored: Tue Sep 22 12:14:15 2015 -0700
Committer: Yin Huai <yh...@databricks.com>
Committed: Tue Sep 22 12:15:16 2015 -0700

----------------------------------------------------------------------
 .../sql/catalyst/optimizer/Optimizer.scala      | 69 ++++++++++++++------
 .../optimizer/SetOperationPushDownSuite.scala   |  3 +-
 .../org/apache/spark/sql/DataFrameSuite.scala   | 41 ++++++++++++
 3 files changed, 93 insertions(+), 20 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/54334d37/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index ce6abc7..fcb5159 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -95,14 +95,14 @@ object SamplePushDown extends Rule[LogicalPlan] {
  * Intersect:
  * It is not safe to pushdown Projections through it because we need to get the
  * intersect of rows by comparing the entire rows. It is fine to pushdown Filters
- * because we will not have non-deterministic expressions.
+ * with deterministic conditions.
  *
  * Except:
  * It is not safe to pushdown Projections through it because we need to get the
  * intersect of rows by comparing the entire rows. It is fine to pushdown Filters
- * because we will not have non-deterministic expressions.
+ * with deterministic conditions.
  */
-object SetOperationPushDown extends Rule[LogicalPlan] {
+object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
 
   /**
    * Maps Attributes from the left side to the corresponding Attribute on the right side.
@@ -129,34 +129,65 @@ object SetOperationPushDown extends Rule[LogicalPlan] {
     result.asInstanceOf[A]
   }
 
+  /**
+   * Splits the condition expression into smaller conditions by `And`, partitions them by
+   * determinism, and finally recombines them with `And`. It returns an expression containing
+   * all deterministic expressions (the first field of the returned Tuple2) and an expression
+   * containing all non-deterministic expressions (the second field of the returned Tuple2).
+   */
+  private def partitionByDeterministic(condition: Expression): (Expression, Expression) = {
+    val andConditions = splitConjunctivePredicates(condition)
+    andConditions.partition(_.deterministic) match {
+      case (deterministic, nondeterministic) =>
+        deterministic.reduceOption(And).getOrElse(Literal(true)) ->
+        nondeterministic.reduceOption(And).getOrElse(Literal(true))
+    }
+  }
+
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     // Push down filter into union
     case Filter(condition, u @ Union(left, right)) =>
+      val (deterministic, nondeterministic) = partitionByDeterministic(condition)
       val rewrites = buildRewrites(u)
-      Union(
-        Filter(condition, left),
-        Filter(pushToRight(condition, rewrites), right))
-
-    // Push down projection through UNION ALL
-    case Project(projectList, u @ Union(left, right)) =>
-      val rewrites = buildRewrites(u)
-      Union(
-        Project(projectList, left),
-        Project(projectList.map(pushToRight(_, rewrites)), right))
+      Filter(nondeterministic,
+        Union(
+          Filter(deterministic, left),
+          Filter(pushToRight(deterministic, rewrites), right)
+        )
+      )
+
+    // Push down deterministic projection through UNION ALL
+    case p @ Project(projectList, u @ Union(left, right)) =>
+      if (projectList.forall(_.deterministic)) {
+        val rewrites = buildRewrites(u)
+        Union(
+          Project(projectList, left),
+          Project(projectList.map(pushToRight(_, rewrites)), right))
+      } else {
+        p
+      }
 
     // Push down filter through INTERSECT
     case Filter(condition, i @ Intersect(left, right)) =>
+      val (deterministic, nondeterministic) = partitionByDeterministic(condition)
       val rewrites = buildRewrites(i)
-      Intersect(
-        Filter(condition, left),
-        Filter(pushToRight(condition, rewrites), right))
+      Filter(nondeterministic,
+        Intersect(
+          Filter(deterministic, left),
+          Filter(pushToRight(deterministic, rewrites), right)
+        )
+      )
 
     // Push down filter through EXCEPT
     case Filter(condition, e @ Except(left, right)) =>
+      val (deterministic, nondeterministic) = partitionByDeterministic(condition)
       val rewrites = buildRewrites(e)
-      Except(
-        Filter(condition, left),
-        Filter(pushToRight(condition, rewrites), right))
+      Filter(nondeterministic,
+        Except(
+          Filter(deterministic, left),
+          Filter(pushToRight(deterministic, rewrites), right)
+        )
+      )
   }
 }
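
The heart of the change is partitionByDeterministic above: only the
deterministic conjuncts of a filter condition cross the set operation, while
the nondeterministic remainder stays in a Filter on top of it. A
self-contained model of that conjunct handling, runnable as a plain Scala
script (Expr, AndExpr, Pred, and TrueLit are stand-ins invented for this
sketch, not Catalyst types):

    sealed trait Expr { def deterministic: Boolean }
    case class AndExpr(l: Expr, r: Expr) extends Expr {
      val deterministic = l.deterministic && r.deterministic
    }
    case class Pred(sql: String, deterministic: Boolean) extends Expr
    case object TrueLit extends Expr { val deterministic = true }

    // Mirrors splitConjunctivePredicates: flatten nested ANDs into a list.
    def splitConjuncts(e: Expr): Seq[Expr] = e match {
      case AndExpr(l, r) => splitConjuncts(l) ++ splitConjuncts(r)
      case other         => Seq(other)
    }

    // Mirrors partitionByDeterministic: a (pushable, must-stay-above) pair,
    // each half defaulting to a true literal when empty.
    def partitionByDeterministic(cond: Expr): (Expr, Expr) = {
      val (det, nondet) = splitConjuncts(cond).partition(_.deterministic)
      (det.reduceOption(AndExpr).getOrElse(TrueLit),
       nondet.reduceOption(AndExpr).getOrElse(TrueLit))
    }

    // (a < 1) AND (rand() < 0.5): only "a < 1" is pushed below the set
    // operation; "rand() < 0.5" is kept in a Filter above it.
    val (push, keep) = partitionByDeterministic(
      AndExpr(Pred("a < 1", deterministic = true),
              Pred("rand() < 0.5", deterministic = false)))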
 

http://git-wip-us.apache.org/repos/asf/spark/blob/54334d37/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala
index 3fca47a..1595ad9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala
@@ -30,7 +30,8 @@ class SetOperationPushDownSuite extends PlanTest {
       Batch("Subqueries", Once,
         EliminateSubQueries) ::
       Batch("Union Pushdown", Once,
-        SetOperationPushDown) :: Nil
+        SetOperationPushDown,
+        SimplifyFilters) :: Nil
   }
 
   val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
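
Why SimplifyFilters joins the batch: when a condition has no nondeterministic
conjuncts, the rewritten rule still wraps the result in Filter(Literal(true),
...), and the suite's plan comparisons would otherwise see that extra node. A
rough stand-in for the cleanup it performs (toy plan types invented for this
sketch, not Catalyst):

    sealed trait Plan
    case class Rel(name: String) extends Plan
    case class TrueFilter(child: Plan) extends Plan // models Filter(Literal(true), _)

    // Drop trivially-true filters, as SimplifyFilters does for Literal(true).
    def simplify(p: Plan): Plan = p match {
      case TrueFilter(child) => simplify(child)
      case other             => other
    }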

http://git-wip-us.apache.org/repos/asf/spark/blob/54334d37/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index d0f484a..36063d8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -896,4 +896,45 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     assert(intersect.count() === 30)
     assert(except.count() === 70)
   }
+
+  test("SPARK-10740: handle nondeterministic expressions correctly for set 
operations") {
+    val df1 = (1 to 20).map(Tuple1.apply).toDF("i")
+    val df2 = (1 to 10).map(Tuple1.apply).toDF("i")
+
+    // When generating expected results here, we need to follow the implementation of
+    // the Rand expression.
+    def expected(df: DataFrame): Seq[Row] = {
+      df.rdd.collectPartitions().zipWithIndex.flatMap {
+        case (data, index) =>
+          val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index)
+          data.filter(_.getInt(0) < rng.nextDouble() * 10)
+      }
+    }
+
+    val union = df1.unionAll(df2)
+    checkAnswer(
+      union.filter('i < rand(7) * 10),
+      expected(union)
+    )
+    checkAnswer(
+      union.select(rand(7)),
+      union.rdd.collectPartitions().zipWithIndex.flatMap {
+        case (data, index) =>
+          val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index)
+          data.map(_ => rng.nextDouble()).map(i => Row(i))
+      }
+    )
+
+    val intersect = df1.intersect(df2)
+    checkAnswer(
+      intersect.filter('i < rand(7) * 10),
+      expected(intersect)
+    )
+
+    val except = df1.except(df2)
+    checkAnswer(
+      except.filter('i < rand(7) * 10),
+      expected(except)
+    )
+  }
 }
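
A note on the expected() helper in this test: it re-derives what rand(seed)
emits by seeding an XORShiftRandom with seed + partitionIndex and drawing one
double per row, in row order, within each partition; that is why the test
collects partitions rather than rows. Reconstructed in isolation (partition
index 0 assumed for illustration; note XORShiftRandom is package-private to
Spark, so this only compiles inside Spark's own sources or tests):

    // Re-deriving rand(7)'s stream for partition 0, as expected() does.
    // XORShiftRandom is Spark's own generator, used by the Rand expression.
    import org.apache.spark.util.random.XORShiftRandom
    val rng = new XORShiftRandom(7 + 0)                 // seed + partitionIndex
    val firstFive = (1 to 5).map(_ => rng.nextDouble()) // one draw per row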

