Repository: spark
Updated Branches:
  refs/heads/master 7bb64aae2 -> 8640cdb83


[SPARK-15441][SQL] support null object in Dataset outer-join

## What changes were proposed in this pull request?

Currently we can't encode a top-level null object into an internal row, as
Spark SQL doesn't allow a row to be null; only its columns can be null.

This was not a problem before, as we assumed the input object was never null.
However, for outer join we do need the semantics of a null object.

This PR fixes the problem by making both join sides produce a single column,
i.e. nesting the logical plan output (by `CreateStruct`), so that we have an
extra level to represent a top-level null object.
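
To illustrate, here is a minimal sketch of the semantics this enables, modeled
on the new test below (the session setup and alias names are assumptions):

```scala
import org.apache.spark.sql.SparkSession

case class ClassData(a: String, b: Int)

object OuterJoinDemo extends App {
  val spark = SparkSession.builder().master("local[*]").appName("SPARK-15441").getOrCreate()
  import spark.implicits._

  val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDS().as("l")
  val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDS().as("r")

  // With this fix, the unmatched right side comes back as a single null
  // object instead of failing to encode a null row.
  left.joinWith(right, $"l.b" === $"r.b", "left").collect().foreach(println)
  // (ClassData(a,1),null)
  // (ClassData(b,2),ClassData(x,2))

  spark.stop()
}
```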

## How was this patch tested?

New test in `DatasetSuite`.

Author: Wenchen Fan <wenc...@databricks.com>

Closes #13425 from cloud-fan/outer-join2.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8640cdb8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8640cdb8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8640cdb8

Branch: refs/heads/master
Commit: 8640cdb836b4964e4af891d9959af64a2e1f304e
Parents: 7bb64aa
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Wed Jun 1 16:16:54 2016 -0700
Committer: Cheng Lian <l...@databricks.com>
Committed: Wed Jun 1 16:16:54 2016 -0700

----------------------------------------------------------------------
 .../catalyst/encoders/ExpressionEncoder.scala   |  3 +-
 .../catalyst/expressions/objects/objects.scala  |  1 -
 .../scala/org/apache/spark/sql/Dataset.scala    | 67 ++++++++++++++------
 .../org/apache/spark/sql/DatasetSuite.scala     | 23 +++----
 4 files changed, 59 insertions(+), 35 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8640cdb8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index f21a39a..2296946 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -125,12 +125,13 @@ object ExpressionEncoder {
         }
       } else {
         val input = BoundReference(index, enc.schema, nullable = true)
-        enc.deserializer.transformUp {
+        val deserialized = enc.deserializer.transformUp {
           case UnresolvedAttribute(nameParts) =>
             assert(nameParts.length == 1)
             UnresolvedExtractValue(input, Literal(nameParts.head))
           case BoundReference(ordinal, dt, _) => GetStructField(input, ordinal)
         }
+        If(IsNull(input), Literal.create(null, deserialized.dataType), deserialized)
       }
     }
 

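The `If(IsNull(input), ...)` guard added above is what lets a whole tuple
element decode to a null object. A plain-Scala analogue of the new expression
tree (illustrative names, not Catalyst code):

```scala
object NullGuardDemo extends App {
  case class Person(name: String, age: Int)

  // Analogue of: If(IsNull(input), Literal.create(null, dataType), deserialized)
  def deserialize(input: Array[Any]): Person =
    if (input == null) {
      null // the whole input struct is null, so the decoded object is null
    } else {
      // per-field reads, analogous to GetStructField(input, ordinal)
      Person(input(0).asInstanceOf[String], input(1).asInstanceOf[Int])
    }

  assert(deserialize(Array("a", 1)) == Person("a", 1))
  assert(deserialize(null) == null) // previously this case could not be expressed
}
```
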
http://git-wip-us.apache.org/repos/asf/spark/blob/8640cdb8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 2f2323f..c2e3ab8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions.objects
 
 import java.lang.reflect.Modifier
 
-import scala.annotation.tailrec
 import scala.language.existentials
 import scala.reflect.ClassTag
 

http://git-wip-us.apache.org/repos/asf/spark/blob/8640cdb8/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 3a6ec45..369b772 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -747,31 +747,62 @@ class Dataset[T] private[sql](
    */
   @Experimental
   def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = {
-    val left = this.logicalPlan
-    val right = other.logicalPlan
-
-    val joined = sparkSession.sessionState.executePlan(Join(left, right, joinType =
-      JoinType(joinType), Some(condition.expr)))
-    val leftOutput = joined.analyzed.output.take(left.output.length)
-    val rightOutput = joined.analyzed.output.takeRight(right.output.length)
+    // Creates a Join node and resolve it first, to get join condition resolved, self-join resolved,
+    // etc.
+    val joined = sparkSession.sessionState.executePlan(
+      Join(
+        this.logicalPlan,
+        other.logicalPlan,
+        JoinType(joinType),
+        Some(condition.expr))).analyzed.asInstanceOf[Join]
+
+    // For both join side, combine all outputs into a single column and alias it with "_1" or "_2",
+    // to match the schema for the encoder of the join result.
+    // Note that we do this before joining them, to enable the join operator to return null for one
+    // side, in cases like outer-join.
+    val left = {
+      val combined = if (this.unresolvedTEncoder.flat) {
+        assert(joined.left.output.length == 1)
+        Alias(joined.left.output.head, "_1")()
+      } else {
+        Alias(CreateStruct(joined.left.output), "_1")()
+      }
+      Project(combined :: Nil, joined.left)
+    }
 
-    val leftData = this.unresolvedTEncoder match {
-      case e if e.flat => Alias(leftOutput.head, "_1")()
-      case _ => Alias(CreateStruct(leftOutput), "_1")()
+    val right = {
+      val combined = if (other.unresolvedTEncoder.flat) {
+        assert(joined.right.output.length == 1)
+        Alias(joined.right.output.head, "_2")()
+      } else {
+        Alias(CreateStruct(joined.right.output), "_2")()
+      }
+      Project(combined :: Nil, joined.right)
     }
-    val rightData = other.unresolvedTEncoder match {
-      case e if e.flat => Alias(rightOutput.head, "_2")()
-      case _ => Alias(CreateStruct(rightOutput), "_2")()
+
+    // Rewrites the join condition to make the attribute point to correct column/field, after we
+    // combine the outputs of each join side.
+    val conditionExpr = joined.condition.get transformUp {
+      case a: Attribute if joined.left.outputSet.contains(a) =>
+        if (this.unresolvedTEncoder.flat) {
+          left.output.head
+        } else {
+          val index = joined.left.output.indexWhere(_.exprId == a.exprId)
+          GetStructField(left.output.head, index)
+        }
+      case a: Attribute if joined.right.outputSet.contains(a) =>
+        if (other.unresolvedTEncoder.flat) {
+          right.output.head
+        } else {
+          val index = joined.right.output.indexWhere(_.exprId == a.exprId)
+          GetStructField(right.output.head, index)
+        }
     }
 
     implicit val tuple2Encoder: Encoder[(T, U)] =
       ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder)
 
-    withTypedPlan {
-      Project(
-        leftData :: rightData :: Nil,
-        joined.analyzed)
-    }
+    withTypedPlan(Join(left, right, joined.joinType, Some(conditionExpr)))
   }
 
   /**

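The reshaped plan can be approximated with the public DataFrame API: each side
is first collapsed into a single struct column ("_1"/"_2", matching the tuple
encoder's schema), and the join condition is rewritten to reach into those
structs, mirroring the `GetStructField` rewrite above. A hedged sketch (column
names and data are assumptions):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.struct

object JoinShapeDemo extends App {
  val spark = SparkSession.builder().master("local[*]").appName("join-shape").getOrCreate()
  import spark.implicits._

  val left = Seq(("a", 1), ("b", 2)).toDF("name", "b")
  val right = Seq(("x", 2), ("y", 3)).toDF("name", "b")

  // Collapse each side to one struct column before joining, so an outer
  // join can return a single NULL for the entire missing side.
  val l = left.select(struct(left("name"), left("b")).as("_1"))
  val r = right.select(struct(right("name"), right("b")).as("_2"))

  // The condition now points into the structs, just as the rewritten
  // condition uses GetStructField on the combined columns.
  l.join(r, $"_1.b" === $"_2.b", "left_outer").show()

  spark.stop()
}
```
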
http://git-wip-us.apache.org/repos/asf/spark/blob/8640cdb8/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 8fc4dc9..0b6874e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -253,21 +253,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       (1, 1), (2, 2))
   }
 
-  test("joinWith, expression condition, outer join") {
-    val nullInteger = null.asInstanceOf[Integer]
-    val nullString = null.asInstanceOf[String]
-    val ds1 = Seq(ClassNullableData("a", 1),
-      ClassNullableData("c", 3)).toDS()
-    val ds2 = Seq(("a", new Integer(1)),
-      ("b", new Integer(2))).toDS()
-
-    checkDataset(
-      ds1.joinWith(ds2, $"_1" === $"a", "outer"),
-      (ClassNullableData("a", 1), ("a", new Integer(1))),
-      (ClassNullableData("c", 3), (nullString, nullInteger)),
-      (ClassNullableData(nullString, nullInteger), ("b", new Integer(2))))
-  }
-
   test("joinWith tuple with primitive, expression") {
     val ds1 = Seq(1, 1, 2).toDS()
     val ds2 = Seq(("a", 1), ("b", 2)).toDS()
@@ -783,6 +768,14 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       ds.filter(_.b > 1).collect().toSeq
     }
   }
+
+  test("SPARK-15441: Dataset outer join") {
+    val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDS().as("left")
+    val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDS().as("right")
+    val joined = left.joinWith(right, $"left.b" === $"right.b", "left")
+    val result = joined.collect().toSet
+    assert(result == Set(ClassData("a", 1) -> null, ClassData("b", 2) -> ClassData("x", 2)))
+  }
 }
 
 case class Generic[T](id: T, value: Double)

