Repository: spark
Updated Branches:
  refs/heads/master 97594c29b -> 94b24b84a


[SPARK-17806] [SQL] fix bug in join key rewritten in HashJoin

## What changes were proposed in this pull request?

In HashJoin, we try to rewrite the join key as a Long to improve the performance 
of finding a match. The rewriting part was not well tested and had a bug that 
could cause wrong results when there are at least three integral columns in the 
join key and the total length of the key exceeds 8 bytes.

## How was this patch tested?

Added unit tests covering the rewriting with different numbers of columns and 
different data types. Manually tested the reported case and confirmed that this 
PR fixes the bug.

Author: Davies Liu <dav...@databricks.com>

Closes #15390 from davies/rewrite_key.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/94b24b84
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/94b24b84
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/94b24b84

Branch: refs/heads/master
Commit: 94b24b84a666517e31e9c9d693f92d9bbfd7f9ad
Parents: 97594c2
Author: Davies Liu <dav...@databricks.com>
Authored: Fri Oct 7 15:03:47 2016 -0700
Committer: Davies Liu <davies....@gmail.com>
Committed: Fri Oct 7 15:03:47 2016 -0700

----------------------------------------------------------------------
 .../spark/sql/execution/joins/HashJoin.scala    | 65 ++++++++++----------
 .../execution/joins/BroadcastJoinSuite.scala    | 47 ++++++++++++++
 2 files changed, 79 insertions(+), 33 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/94b24b84/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 8ddac19..05c5e2f 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -63,45 +63,16 @@ trait HashJoin {
   protected lazy val (buildKeys, streamedKeys) = {
     require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType),
       "Join keys from two sides should have same types")
-    val lkeys = rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, 
left.output))
-    val rkeys = rewriteKeyExpr(rightKeys).map(BindReferences.bindReference(_, 
right.output))
+    val lkeys = 
HashJoin.rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, 
left.output))
+    val rkeys = HashJoin.rewriteKeyExpr(rightKeys)
+      .map(BindReferences.bindReference(_, right.output))
     buildSide match {
       case BuildLeft => (lkeys, rkeys)
       case BuildRight => (rkeys, lkeys)
     }
   }
 
-  /**
-   * Try to rewrite the key as LongType so we can use getLong(), if they key 
can fit with a long.
-   *
-   * If not, returns the original expressions.
-   */
-  private def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = {
-    var keyExpr: Expression = null
-    var width = 0
-    keys.foreach { e =>
-      e.dataType match {
-        case dt: IntegralType if dt.defaultSize <= 8 - width =>
-          if (width == 0) {
-            if (e.dataType != LongType) {
-              keyExpr = Cast(e, LongType)
-            } else {
-              keyExpr = e
-            }
-            width = dt.defaultSize
-          } else {
-            val bits = dt.defaultSize * 8
-            keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)),
-              BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1)))
-            width -= bits
-          }
-        // TODO: support BooleanType, DateType and TimestampType
-        case other =>
-          return keys
-      }
-    }
-    keyExpr :: Nil
-  }
+
 
   protected def buildSideKeyGenerator(): Projection =
     UnsafeProjection.create(buildKeys)
@@ -247,3 +218,31 @@ trait HashJoin {
     }
   }
 }
+
+object HashJoin {
+  /**
+   * Try to rewrite the key as LongType so we can use getLong(), if they key 
can fit with a long.
+   *
+   * If not, returns the original expressions.
+   */
+  private[joins] def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = {
+    assert(keys.nonEmpty)
+    // TODO: support BooleanType, DateType and TimestampType
+    if (keys.exists(!_.dataType.isInstanceOf[IntegralType])
+      || keys.map(_.dataType.defaultSize).sum > 8) {
+      return keys
+    }
+
+    var keyExpr: Expression = if (keys.head.dataType != LongType) {
+      Cast(keys.head, LongType)
+    } else {
+      keys.head
+    }
+    keys.tail.foreach { e =>
+      val bits = e.dataType.defaultSize * 8
+      keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)),
+        BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1)))
+    }
+    keyExpr :: Nil
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/94b24b84/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 97adffa..83db81e 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -21,11 +21,13 @@ import scala.reflect.ClassTag
 
 import org.apache.spark.AccumulatorSuite
 import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession}
+import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, 
Literal, ShiftLeft}
 import org.apache.spark.sql.execution.exchange.EnsureRequirements
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.types.{LongType, ShortType}
 
 /**
  * Test various broadcast join operators.
@@ -153,4 +155,49 @@ class BroadcastJoinSuite extends QueryTest with 
SQLTestUtils {
       cases.foreach(assertBroadcastJoin)
     }
   }
+
+  test("join key rewritten") {
+    val l = Literal(1L)
+    val i = Literal(2)
+    val s = Literal.create(3, ShortType)
+    val ss = Literal("hello")
+
+    assert(HashJoin.rewriteKeyExpr(l :: Nil) === l :: Nil)
+    assert(HashJoin.rewriteKeyExpr(l :: l :: Nil) === l :: l :: Nil)
+    assert(HashJoin.rewriteKeyExpr(l :: i :: Nil) === l :: i :: Nil)
+
+    assert(HashJoin.rewriteKeyExpr(i :: Nil) === Cast(i, LongType) :: Nil)
+    assert(HashJoin.rewriteKeyExpr(i :: l :: Nil) === i :: l :: Nil)
+    assert(HashJoin.rewriteKeyExpr(i :: i :: Nil) ===
+      BitwiseOr(ShiftLeft(Cast(i, LongType), Literal(32)),
+        BitwiseAnd(Cast(i, LongType), Literal((1L << 32) - 1))) :: Nil)
+    assert(HashJoin.rewriteKeyExpr(i :: i :: i :: Nil) === i :: i :: i :: Nil)
+
+    assert(HashJoin.rewriteKeyExpr(s :: Nil) === Cast(s, LongType) :: Nil)
+    assert(HashJoin.rewriteKeyExpr(s :: l :: Nil) === s :: l :: Nil)
+    assert(HashJoin.rewriteKeyExpr(s :: s :: Nil) ===
+      BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)),
+        BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil)
+    assert(HashJoin.rewriteKeyExpr(s :: s :: s :: Nil) ===
+      BitwiseOr(ShiftLeft(
+        BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)),
+          BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))),
+        Literal(16)),
+        BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil)
+    assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: Nil) ===
+      BitwiseOr(ShiftLeft(
+        BitwiseOr(ShiftLeft(
+          BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)),
+            BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))),
+          Literal(16)),
+          BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))),
+        Literal(16)),
+        BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil)
+    assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: s :: Nil) ===
+      s :: s :: s :: s :: s :: Nil)
+
+    assert(HashJoin.rewriteKeyExpr(ss :: Nil) === ss :: Nil)
+    assert(HashJoin.rewriteKeyExpr(l :: ss :: Nil) === l :: ss :: Nil)
+    assert(HashJoin.rewriteKeyExpr(i :: ss :: Nil) === i :: ss :: Nil)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to