Repository: spark
Updated Branches:
  refs/heads/branch-1.5 47bc6c0fa -> edc509586


[SPARK-11009] [SQL] fix wrong result of Window function in cluster mode

Currently, all window functions can sometimes produce wrong results in cluster
mode.

The root cause is that AttributeReference is created on the executor, so its id
may collide with the ids of attributes created on the driver.

Here is the script that could reproduce the problem (run in local cluster):
```
from pyspark import SparkContext, HiveContext
from pyspark.sql.window import Window
from pyspark.sql.functions import rowNumber

sqlContext = HiveContext(SparkContext())
sqlContext.setConf("spark.sql.shuffle.partitions", "3")
df =  sqlContext.range(1<<20)
df2 = df.select((df.id % 1000).alias("A"), (df.id / 1000).alias('B'))
ws = Window.partitionBy(df2.A).orderBy(df2.B)
df3 = df2.select("A", "B", rowNumber().over(ws).alias("rn")).filter("rn < 0")
assert df3.count() == 0
```

Author: Davies Liu <[email protected]>
Author: Yin Huai <[email protected]>

Closes #9050 from davies/wrong_window.

Conflicts:
        
sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/edc50958
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/edc50958
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/edc50958

Branch: refs/heads/branch-1.5
Commit: edc509586c3aee4b44e0659724d2e3d81d8518de
Parents: 47bc6c0
Author: Davies Liu <[email protected]>
Authored: Tue Oct 13 09:40:36 2015 -0700
Committer: Yin Huai <[email protected]>
Committed: Tue Oct 13 09:47:28 2015 -0700

----------------------------------------------------------------------
 .../org/apache/spark/sql/execution/Window.scala | 20 ++++-----
 .../spark/sql/hive/HiveSparkSubmitSuite.scala   | 43 +++++++++++++++++++-
 2 files changed, 52 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/edc50958/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index f892953..55035f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -145,11 +145,10 @@ case class Window(
         // Construct the ordering. This is used to compare the result of 
current value projection
         // to the result of bound value projection. This is done manually 
because we want to use
         // Code Generation (if it is enabled).
-        val (sortExprs, schema) = exprs.map { case e =>
-          val ref = AttributeReference("ordExpr", e.dataType, e.nullable)()
-          (SortOrder(ref, e.direction), ref)
-        }.unzip
-        val ordering = newOrdering(sortExprs, schema)
+        val sortExprs = exprs.zipWithIndex.map { case (e, i) =>
+          SortOrder(BoundReference(i, e.dataType, e.nullable), e.direction)
+        }
+        val ordering = newOrdering(sortExprs, Nil)
         RangeBoundOrdering(ordering, current, bound)
       case RowFrame => RowBoundOrdering(offset)
     }
@@ -205,14 +204,15 @@ case class Window(
    */
   private[this] def createResultProjection(
       expressions: Seq[Expression]): MutableProjection = {
-    val unboundToAttr = expressions.map {
-      e => (e, AttributeReference("windowResult", e.dataType, e.nullable)())
+    val references = expressions.zipWithIndex.map{ case (e, i) =>
+      // Results of window expressions will be on the right side of child's 
output
+      BoundReference(child.output.size + i, e.dataType, e.nullable)
     }
-    val unboundToAttrMap = unboundToAttr.toMap
-    val patchedWindowExpression = 
windowExpression.map(_.transform(unboundToAttrMap))
+    val unboundToRefMap = expressions.zip(references).toMap
+    val patchedWindowExpression = 
windowExpression.map(_.transform(unboundToRefMap))
     newMutableProjection(
       projectList ++ patchedWindowExpression,
-      child.output ++ unboundToAttr.map(_._2))()
+      child.output)()
   }
 
   protected override def doExecute(): RDD[InternalRow] = {

http://git-wip-us.apache.org/repos/asf/spark/blob/edc50958/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
index 1d5ee22..788118b 100644
--- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
+++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
@@ -29,7 +29,8 @@ import 
org.scalatest.exceptions.TestFailedDueToTimeoutException
 import org.scalatest.time.SpanSugar._
 
 import org.apache.spark._
-import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.{SQLContext, QueryTest}
+import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext}
 import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer
 import org.apache.spark.sql.types.DecimalType
@@ -107,6 +108,16 @@ class HiveSparkSubmitSuite
     runSparkSubmit(args)
   }
 
+  test("SPARK-11009 fix wrong result of Window function in cluster mode") {
+    val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
+    val args = Seq(
+      "--class", SPARK_11009.getClass.getName.stripSuffix("$"),
+      "--name", "SparkSQLConfTest",
+      "--master", "local-cluster[2,1,1024]",
+      unusedJar.toString)
+    runSparkSubmit(args)
+  }
+
   // NOTE: This is an expensive operation in terms of time (10 seconds+). Use 
sparingly.
   // This is copied from org.apache.spark.deploy.SparkSubmitSuite
   private def runSparkSubmit(args: Seq[String]): Unit = {
@@ -314,3 +325,33 @@ object SPARK_9757 extends QueryTest with Logging {
     }
   }
 }
+
+object SPARK_11009 extends QueryTest {
+  import org.apache.spark.sql.functions._
+
+  protected var sqlContext: SQLContext = _
+
+  def main(args: Array[String]): Unit = {
+    Utils.configTestLog4j("INFO")
+
+    val sparkContext = new SparkContext(
+      new SparkConf()
+        .set("spark.ui.enabled", "false")
+        .set("spark.sql.shuffle.partitions", "100"))
+
+    val hiveContext = new TestHiveContext(sparkContext)
+    sqlContext = hiveContext
+
+    try {
+      val df = sqlContext.range(1 << 20)
+      val df2 = df.select((df("id") % 1000).alias("A"), (df("id") / 
1000).alias("B"))
+      val ws = Window.partitionBy(df2("A")).orderBy(df2("B"))
+      val df3 = df2.select(df2("A"), df2("B"), 
rowNumber().over(ws).alias("rn")).filter("rn < 0")
+      if (df3.rdd.count() != 0) {
+        throw new Exception("df3 should have 0 output row.")
+      }
+    } finally {
+      sparkContext.stop()
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to