Repository: spark
Updated Branches:
  refs/heads/branch-1.5 d83dcc9a0 -> 6b1e5c2db


[SPARK-10737] [SQL] When using UnsafeRows, SortMergeJoin may return wrong results

https://issues.apache.org/jira/browse/SPARK-10737

Author: Yin Huai <yh...@databricks.com>

Closes #8854 from yhuai/SMJBug.

(cherry picked from commit 5aea987c904b281d7952ad8db40a32561b4ec5cf)
Signed-off-by: Yin Huai <yh...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6b1e5c2d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6b1e5c2d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6b1e5c2d

Branch: refs/heads/branch-1.5
Commit: 6b1e5c2dbaf19729d6bb650bb0d0f5fe7a58f703
Parents: d83dcc9
Author: Yin Huai <yh...@databricks.com>
Authored: Tue Sep 22 13:31:35 2015 -0700
Committer: Yin Huai <yh...@databricks.com>
Committed: Tue Sep 22 13:31:45 2015 -0700

----------------------------------------------------------------------
 .../codegen/GenerateProjection.scala            |  2 ++
 .../org/apache/spark/sql/execution/Window.scala |  9 +++++--
 .../sql/execution/joins/SortMergeJoin.scala     | 25 ++++++++++++++---
 .../org/apache/spark/sql/SQLQuerySuite.scala    | 28 ++++++++++++++++++++
 4 files changed, 59 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6b1e5c2d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index c744e84..da85caf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -171,6 +171,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
 
       @Override
       public Object apply(Object r) {
+        // GenerateProjection does not work with UnsafeRows.
+        assert(!(r instanceof ${classOf[UnsafeRow].getName}));
         return new SpecificRow((InternalRow) r);
       }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/6b1e5c2d/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index 0269d6d..f892953 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -253,7 +253,11 @@ case class Window(
 
         // Get all relevant projections.
         val result = createResultProjection(unboundExpressions)
-        val grouping = newProjection(partitionSpec, child.output)
+        val grouping = if (child.outputsUnsafeRows) {
+          UnsafeProjection.create(partitionSpec, child.output)
+        } else {
+          newProjection(partitionSpec, child.output)
+        }
 
         // Manage the stream and the grouping.
         var nextRow: InternalRow = EmptyRow
@@ -277,7 +281,8 @@ case class Window(
         val numFrames = frames.length
         private[this] def fetchNextPartition() {
           // Collect all the rows in the current partition.
-          val currentGroup = nextGroup
+          // Before we start to fetch new input rows, make a copy of nextGroup.
+          val currentGroup = nextGroup.copy()
           rows = new CompactBuffer
           while (nextRowAvailable && nextGroup == currentGroup) {
             rows += nextRow.copy()

http://git-wip-us.apache.org/repos/asf/spark/blob/6b1e5c2d/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index 6b73226..69afb6b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -56,9 +56,6 @@ case class SortMergeJoin(
   override def requiredChildOrdering: Seq[Seq[SortOrder]] =
     requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil
 
-  @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, 
left.output)
-  @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, 
right.output)
-
   protected[this] def isUnsafeMode: Boolean = {
     (codegenEnabled && unsafeEnabled
       && UnsafeProjection.canSupport(leftKeys)
@@ -82,6 +79,28 @@ case class SortMergeJoin(
 
     left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
       new RowIterator {
+        // The projection used to extract keys from input rows of the left 
child.
+        private[this] val leftKeyGenerator = {
+          if (isUnsafeMode) {
+            // It is very important to use UnsafeProjection if input rows are 
UnsafeRows.
+            // Otherwise, GenerateProjection will cause wrong results.
+            UnsafeProjection.create(leftKeys, left.output)
+          } else {
+            newProjection(leftKeys, left.output)
+          }
+        }
+
+        // The projection used to extract keys from input rows of the right 
child.
+        private[this] val rightKeyGenerator = {
+          if (isUnsafeMode) {
+            // It is very important to use UnsafeProjection if input rows are 
UnsafeRows.
+            // Otherwise, GenerateProjection will cause wrong results.
+            UnsafeProjection.create(rightKeys, right.output)
+          } else {
+            newProjection(rightKeys, right.output)
+          }
+        }
+
         // An ordering that can be used to compare keys from both sides.
         private[this] val keyOrdering = 
newNaturalAscendingOrdering(leftKeys.map(_.dataType))
         private[this] var currentLeftRow: InternalRow = _

http://git-wip-us.apache.org/repos/asf/spark/blob/6b1e5c2d/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 9e172b2..4f31bd0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1717,4 +1717,32 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     checkAnswer(
       sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), 
Seq(Row(1), Row(0)))
   }
+
+  test("SortMergeJoin returns wrong results when using UnsafeRows") {
+    // This test is for the fix of 
https://issues.apache.org/jira/browse/SPARK-10737.
+    // This bug will be triggered when Tungsten is enabled and there are 
multiple
+    // SortMergeJoin operators executed in the same task.
+    val confs =
+      SQLConf.SORTMERGE_JOIN.key -> "true" ::
+        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1" ::
+        SQLConf.TUNGSTEN_ENABLED.key -> "true" :: Nil
+    withSQLConf(confs: _*) {
+      val df1 = (1 to 50).map(i => (s"str_$i", i)).toDF("i", "j")
+      val df2 =
+        df1
+          .join(df1.select(df1("i")), "i")
+          .select(df1("i"), df1("j"))
+
+      val df3 = df2.withColumnRenamed("i", "i1").withColumnRenamed("j", "j1")
+      val df4 =
+        df2
+          .join(df3, df2("i") === df3("i1"))
+          .withColumn("diff", $"j" - $"j1")
+          .select(df2("i"), df2("j"), $"diff")
+
+      checkAnswer(
+        df4,
+        df1.withColumn("diff", lit(0)))
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to