Repository: spark
Updated Branches:
  refs/heads/branch-1.6 a72abebcb -> ad2ebe4db


[SPARK-11803][SQL] fix Dataset self-join

When we resolve the join operator, we may change the output of right side if 
self-join is detected. So in `Dataset.joinWith`, we should resolve the join 
operator first, and then get the left output and right output from it, instead 
of using `left.output` and `right.output` directly.

Author: Wenchen Fan <[email protected]>

Closes #9806 from cloud-fan/self-join.

(cherry picked from commit cffb899c4397ecccedbcc41e7cf3da91f953435a)
Signed-off-by: Michael Armbrust <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ad2ebe4d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ad2ebe4d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ad2ebe4d

Branch: refs/heads/branch-1.6
Commit: ad2ebe4db52c0d0267fd7e9610775db61dcd7706
Parents: a72abeb
Author: Wenchen Fan <[email protected]>
Authored: Wed Nov 18 10:15:50 2015 -0800
Committer: Michael Armbrust <[email protected]>
Committed: Wed Nov 18 10:16:01 2015 -0800

----------------------------------------------------------------------
 .../src/main/scala/org/apache/spark/sql/Dataset.scala | 14 +++++++++-----
 .../scala/org/apache/spark/sql/DatasetSuite.scala     |  8 ++++----
 2 files changed, 13 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ad2ebe4d/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 817c20f..b644f6a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -498,13 +498,17 @@ class Dataset[T] private[sql](
     val left = this.logicalPlan
     val right = other.logicalPlan
 
+    val joined = sqlContext.executePlan(Join(left, right, Inner, Some(condition.expr)))
+    val leftOutput = joined.analyzed.output.take(left.output.length)
+    val rightOutput = joined.analyzed.output.takeRight(right.output.length)
+
     val leftData = this.unresolvedTEncoder match {
-      case e if e.flat => Alias(left.output.head, "_1")()
-      case _ => Alias(CreateStruct(left.output), "_1")()
+      case e if e.flat => Alias(leftOutput.head, "_1")()
+      case _ => Alias(CreateStruct(leftOutput), "_1")()
     }
     val rightData = other.unresolvedTEncoder match {
-      case e if e.flat => Alias(right.output.head, "_2")()
-      case _ => Alias(CreateStruct(right.output), "_2")()
+      case e if e.flat => Alias(rightOutput.head, "_2")()
+      case _ => Alias(CreateStruct(rightOutput), "_2")()
     }
 
 
@@ -513,7 +517,7 @@ class Dataset[T] private[sql](
     withPlan[(T, U)](other) { (left, right) =>
       Project(
         leftData :: rightData :: Nil,
-        Join(left, right, Inner, Some(condition.expr)))
+        joined.analyzed)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ad2ebe4d/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index a522894..198962b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -347,7 +347,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     checkAnswer(joined, ("2", 2))
   }
 
-  ignore("self join") {
+  test("self join") {
     val ds = Seq("1", "2").toDS().as("a")
     val joined = ds.joinWith(ds, lit(true))
     checkAnswer(joined, ("1", "1"), ("1", "2"), ("2", "1"), ("2", "2"))
@@ -360,15 +360,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
 
   test("kryo encoder") {
     implicit val kryoEncoder = Encoders.kryo[KryoData]
-    val ds = sqlContext.createDataset(Seq(KryoData(1), KryoData(2)))
+    val ds = Seq(KryoData(1), KryoData(2)).toDS()
 
     assert(ds.groupBy(p => p).count().collect().toSeq ==
       Seq((KryoData(1), 1L), (KryoData(2), 1L)))
   }
 
-  ignore("kryo encoder self join") {
+  test("kryo encoder self join") {
     implicit val kryoEncoder = Encoders.kryo[KryoData]
-    val ds = sqlContext.createDataset(Seq(KryoData(1), KryoData(2)))
+    val ds = Seq(KryoData(1), KryoData(2)).toDS()
     assert(ds.joinWith(ds, lit(true)).collect().toSet ==
       Set(
         (KryoData(1), KryoData(1)),


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to