Repository: spark
Updated Branches:
  refs/heads/master b13ef7723 -> 170723860


[SPARK-7026] [SQL] fix left semi join with equi key and non-equi condition

When the `condition` extracted by `ExtractEquiJoinKeys` contain join Predicate 
for left semi join, we can not plan it as semiJoin. Such as

    SELECT * FROM testData2 x
    LEFT SEMI JOIN testData2 y
    ON x.b = y.b
    AND x.a >= y.a + 2
Condition `x.a >= y.a + 2` can not evaluate on table `x`, so it throw errors

Author: Daoyuan Wang <daoyuan.w...@intel.com>

Closes #5643 from adrian-wang/spark7026 and squashes the following commits:

cc09809 [Daoyuan Wang] refactor semijoin and add plan test
575a7c8 [Daoyuan Wang] fix notserializable
27841de [Daoyuan Wang] fix rebase
10bf124 [Daoyuan Wang] fix style
72baa02 [Daoyuan Wang] fix style
8e0afca [Daoyuan Wang] merge commits for rebase


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/17072386
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/17072386
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/17072386

Branch: refs/heads/master
Commit: 1707238601690fd0e8e173e2c47f1b4286644a29
Parents: b13ef77
Author: Daoyuan Wang <daoyuan.w...@intel.com>
Authored: Fri Jul 17 16:45:46 2015 -0700
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Fri Jul 17 16:45:46 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/execution/SparkStrategies.scala   | 10 +--
 .../joins/BroadcastLeftSemiJoinHash.scala       | 42 ++++-----
 .../sql/execution/joins/HashOuterJoin.scala     |  3 +-
 .../sql/execution/joins/HashSemiJoin.scala      | 91 ++++++++++++++++++++
 .../sql/execution/joins/LeftSemiJoinHash.scala  | 35 +++-----
 .../org/apache/spark/sql/SQLQuerySuite.scala    | 12 +++
 .../sql/execution/joins/SemiJoinSuite.scala     | 74 ++++++++++++++++
 7 files changed, 208 insertions(+), 59 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/17072386/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 73b4634..240332a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -38,14 +38,12 @@ private[sql] abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
       case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, 
right)
         if sqlContext.conf.autoBroadcastJoinThreshold > 0 &&
           right.statistics.sizeInBytes <= 
sqlContext.conf.autoBroadcastJoinThreshold =>
-        val semiJoin = joins.BroadcastLeftSemiJoinHash(
-          leftKeys, rightKeys, planLater(left), planLater(right))
-        condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
+        joins.BroadcastLeftSemiJoinHash(
+          leftKeys, rightKeys, planLater(left), planLater(right), condition) 
:: Nil
       // Find left semi joins where at least some predicates can be evaluated 
by matching join keys
       case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, 
right) =>
-        val semiJoin = joins.LeftSemiJoinHash(
-          leftKeys, rightKeys, planLater(left), planLater(right))
-        condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
+        joins.LeftSemiJoinHash(
+          leftKeys, rightKeys, planLater(left), planLater(right), condition) 
:: Nil
       // no predicate can be evaluated by matching hash keys
       case logical.Join(left, right, LeftSemi, condition) =>
         joins.LeftSemiJoinBNL(planLater(left), planLater(right), condition) :: 
Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/17072386/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
index f7b46d6..2750f58 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
@@ -33,37 +33,27 @@ case class BroadcastLeftSemiJoinHash(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
     left: SparkPlan,
-    right: SparkPlan) extends BinaryNode with HashJoin {
-
-  override val buildSide: BuildSide = BuildRight
-
-  override def output: Seq[Attribute] = left.output
+    right: SparkPlan,
+    condition: Option[Expression]) extends BinaryNode with HashSemiJoin {
 
   protected override def doExecute(): RDD[InternalRow] = {
-    val buildIter = buildPlan.execute().map(_.copy()).collect().toIterator
-    val hashSet = new java.util.HashSet[InternalRow]()
-    var currentRow: InternalRow = null
+    val buildIter = right.execute().map(_.copy()).collect().toIterator
 
-    // Create a Hash set of buildKeys
-    while (buildIter.hasNext) {
-      currentRow = buildIter.next()
-      val rowKey = buildSideKeyGenerator(currentRow)
-      if (!rowKey.anyNull) {
-        val keyExists = hashSet.contains(rowKey)
-        if (!keyExists) {
-          // rowKey may be not serializable (from codegen)
-          hashSet.add(rowKey.copy())
-        }
-      }
-    }
+    if (condition.isEmpty) {
+      // rowKey may be not serializable (from codegen)
+      val hashSet = buildKeyHashSet(buildIter, copy = true)
+      val broadcastedRelation = sparkContext.broadcast(hashSet)
 
-    val broadcastedRelation = sparkContext.broadcast(hashSet)
+      left.execute().mapPartitions { streamIter =>
+        hashSemiJoin(streamIter, broadcastedRelation.value)
+      }
+    } else {
+      val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
+      val broadcastedRelation = sparkContext.broadcast(hashRelation)
 
-    streamedPlan.execute().mapPartitions { streamIter =>
-      val joinKeys = streamSideKeyGenerator()
-      streamIter.filter(current => {
-        !joinKeys(current).anyNull && 
broadcastedRelation.value.contains(joinKeys.currentValue)
-      })
+      left.execute().mapPartitions { streamIter =>
+        hashSemiJoin(streamIter, broadcastedRelation.value)
+      }
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/17072386/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
index 0522ee8..74a7db7 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
@@ -65,8 +65,7 @@ override def outputPartitioning: Partitioning = joinType 
match {
   @transient private[this] lazy val leftNullRow = new 
GenericInternalRow(left.output.length)
   @transient private[this] lazy val rightNullRow = new 
GenericInternalRow(right.output.length)
   @transient private[this] lazy val boundCondition =
-    condition.map(
-      newPredicate(_, left.output ++ right.output)).getOrElse((row: 
InternalRow) => true)
+    newPredicate(condition.getOrElse(Literal(true)), left.output ++ 
right.output)
 
   // TODO we need to rewrite all of the iterators with our own implementation 
instead of the Scala
   // iterator for performance purpose.

http://git-wip-us.apache.org/repos/asf/spark/blob/17072386/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
new file mode 100644
index 0000000..1b983bc
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.SparkPlan
+
+
+trait HashSemiJoin {
+  self: SparkPlan =>
+  val leftKeys: Seq[Expression]
+  val rightKeys: Seq[Expression]
+  val left: SparkPlan
+  val right: SparkPlan
+  val condition: Option[Expression]
+
+  override def output: Seq[Attribute] = left.output
+
+  @transient protected lazy val rightKeyGenerator: Projection =
+    newProjection(rightKeys, right.output)
+
+  @transient protected lazy val leftKeyGenerator: () => MutableProjection =
+    newMutableProjection(leftKeys, left.output)
+
+  @transient private lazy val boundCondition =
+    newPredicate(condition.getOrElse(Literal(true)), left.output ++ 
right.output)
+
+  protected def buildKeyHashSet(
+      buildIter: Iterator[InternalRow],
+      copy: Boolean): java.util.Set[InternalRow] = {
+    val hashSet = new java.util.HashSet[InternalRow]()
+    var currentRow: InternalRow = null
+
+    // Create a Hash set of buildKeys
+    while (buildIter.hasNext) {
+      currentRow = buildIter.next()
+      val rowKey = rightKeyGenerator(currentRow)
+      if (!rowKey.anyNull) {
+        val keyExists = hashSet.contains(rowKey)
+        if (!keyExists) {
+          if (copy) {
+            hashSet.add(rowKey.copy())
+          } else {
+            // rowKey may be not serializable (from codegen)
+            hashSet.add(rowKey)
+          }
+        }
+      }
+    }
+    hashSet
+  }
+
+  protected def hashSemiJoin(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation): Iterator[InternalRow] = {
+    val joinKeys = leftKeyGenerator()
+    val joinedRow = new JoinedRow
+    streamIter.filter(current => {
+      lazy val rowBuffer = hashedRelation.get(joinKeys.currentValue)
+      !joinKeys(current).anyNull && rowBuffer != null && rowBuffer.exists {
+        (build: InternalRow) => boundCondition(joinedRow(current, build))
+      }
+    })
+  }
+
+  protected def hashSemiJoin(
+      streamIter: Iterator[InternalRow],
+      hashSet: java.util.Set[InternalRow]): Iterator[InternalRow] = {
+    val joinKeys = leftKeyGenerator()
+    val joinedRow = new JoinedRow
+    streamIter.filter(current => {
+      !joinKeys(current.copy()).anyNull && 
hashSet.contains(joinKeys.currentValue)
+    })
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/17072386/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
index 611ba92..9eaac81 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution
 import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 
@@ -34,36 +34,21 @@ case class LeftSemiJoinHash(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
     left: SparkPlan,
-    right: SparkPlan) extends BinaryNode with HashJoin {
-
-  override val buildSide: BuildSide = BuildRight
+    right: SparkPlan,
+    condition: Option[Expression]) extends BinaryNode with HashSemiJoin {
 
   override def requiredChildDistribution: Seq[ClusteredDistribution] =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
-  override def output: Seq[Attribute] = left.output
-
   protected override def doExecute(): RDD[InternalRow] = {
-    buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, 
streamIter) =>
-      val hashSet = new java.util.HashSet[InternalRow]()
-      var currentRow: InternalRow = null
-
-      // Create a Hash set of buildKeys
-      while (buildIter.hasNext) {
-        currentRow = buildIter.next()
-        val rowKey = buildSideKeyGenerator(currentRow)
-        if (!rowKey.anyNull) {
-          val keyExists = hashSet.contains(rowKey)
-          if (!keyExists) {
-            hashSet.add(rowKey)
-          }
-        }
+    right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) =>
+      if (condition.isEmpty) {
+        val hashSet = buildKeyHashSet(buildIter, copy = false)
+        hashSemiJoin(streamIter, hashSet)
+      } else {
+        val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
+        hashSemiJoin(streamIter, hashRelation)
       }
-
-      val joinKeys = streamSideKeyGenerator()
-      streamIter.filter(current => {
-        !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue)
-      })
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/17072386/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 5b8b70e..61d5f20 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -395,6 +395,18 @@ class SQLQuerySuite extends QueryTest with 
BeforeAndAfterAll with SQLTestUtils {
     )
   }
 
+  test("left semi greater than predicate and equal operator") {
+    checkAnswer(
+      sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.b = y.b 
and x.a >= y.a + 2"),
+      Seq(Row(3, 1), Row(3, 2))
+    )
+
+    checkAnswer(
+      sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.b = y.a 
and x.a >= y.b + 1"),
+      Seq(Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2))
+    )
+  }
+
   test("index into array of arrays") {
     checkAnswer(
       sql(

http://git-wip-us.apache.org/repos/asf/spark/blob/17072386/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
new file mode 100644
index 0000000..927e85a
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.{LessThan, Expression}
+import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
+
+
+class SemiJoinSuite extends SparkPlanTest{
+  val left = Seq(
+    (1, 2.0),
+    (1, 2.0),
+    (2, 1.0),
+    (2, 1.0),
+    (3, 3.0)
+  ).toDF("a", "b")
+
+  val right = Seq(
+    (2, 3.0),
+    (2, 3.0),
+    (3, 2.0),
+    (4, 1.0)
+  ).toDF("c", "d")
+
+  val leftKeys: List[Expression] = 'a :: Nil
+  val rightKeys: List[Expression] = 'c :: Nil
+  val condition = Some(LessThan('b, 'd))
+
+  test("left semi join hash") {
+    checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
+      LeftSemiJoinHash(leftKeys, rightKeys, left, right, condition),
+      Seq(
+        (2, 1.0),
+        (2, 1.0)
+      ).map(Row.fromTuple))
+  }
+
+  test("left semi join BNL") {
+    checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
+      LeftSemiJoinBNL(left, right, condition),
+      Seq(
+        (1, 2.0),
+        (1, 2.0),
+        (2, 1.0),
+        (2, 1.0)
+      ).map(Row.fromTuple))
+  }
+
+  test("broadcast left semi join hash") {
+    checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
+      BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, condition),
+      Seq(
+        (2, 1.0),
+        (2, 1.0)
+      ).map(Row.fromTuple))
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to