Repository: spark
Updated Branches:
  refs/heads/branch-2.0 a84d8ef37 -> 6d056c168


[SPARK-17806] [SQL] fix bug in join key rewritten in HashJoin

## What changes were proposed in this pull request?

In HashJoin, we try to rewrite the join key as a Long to improve the performance 
of finding a match. The rewriting part is not well tested and has a bug that 
could cause wrong results when there are at least three integral columns in the 
join key and the total length of the key exceeds 8 bytes.

## How was this patch tested?

Added a unit test covering the rewriting with different numbers of columns and 
different data types. Manually tested the reported case and confirmed that this 
PR fixes the bug.

Author: Davies Liu <[email protected]>

Closes #15390 from davies/rewrite_key.

(cherry picked from commit 94b24b84a666517e31e9c9d693f92d9bbfd7f9ad)
Signed-off-by: Davies Liu <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6d056c16
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6d056c16
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6d056c16

Branch: refs/heads/branch-2.0
Commit: 6d056c168c45d2decf5ffbb96d59623d52ed8490
Parents: a84d8ef
Author: Davies Liu <[email protected]>
Authored: Fri Oct 7 15:03:47 2016 -0700
Committer: Davies Liu <[email protected]>
Committed: Fri Oct 7 15:03:58 2016 -0700

----------------------------------------------------------------------
 .../spark/sql/execution/joins/HashJoin.scala    | 65 ++++++++++----------
 .../execution/joins/BroadcastJoinSuite.scala    | 47 ++++++++++++++
 2 files changed, 79 insertions(+), 33 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6d056c16/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index d46a804..d11f7e6 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -63,45 +63,16 @@ trait HashJoin {
   protected lazy val (buildKeys, streamedKeys) = {
     require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType),
       "Join keys from two sides should have same types")
-    val lkeys = rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, 
left.output))
-    val rkeys = rewriteKeyExpr(rightKeys).map(BindReferences.bindReference(_, 
right.output))
+    val lkeys = 
HashJoin.rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, 
left.output))
+    val rkeys = HashJoin.rewriteKeyExpr(rightKeys)
+      .map(BindReferences.bindReference(_, right.output))
     buildSide match {
       case BuildLeft => (lkeys, rkeys)
       case BuildRight => (rkeys, lkeys)
     }
   }
 
-  /**
-   * Try to rewrite the key as LongType so we can use getLong(), if they key 
can fit with a long.
-   *
-   * If not, returns the original expressions.
-   */
-  private def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = {
-    var keyExpr: Expression = null
-    var width = 0
-    keys.foreach { e =>
-      e.dataType match {
-        case dt: IntegralType if dt.defaultSize <= 8 - width =>
-          if (width == 0) {
-            if (e.dataType != LongType) {
-              keyExpr = Cast(e, LongType)
-            } else {
-              keyExpr = e
-            }
-            width = dt.defaultSize
-          } else {
-            val bits = dt.defaultSize * 8
-            keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)),
-              BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1)))
-            width -= bits
-          }
-        // TODO: support BooleanType, DateType and TimestampType
-        case other =>
-          return keys
-      }
-    }
-    keyExpr :: Nil
-  }
+
 
   protected def buildSideKeyGenerator(): Projection =
     UnsafeProjection.create(buildKeys)
@@ -247,3 +218,31 @@ trait HashJoin {
     }
   }
 }
+
+object HashJoin {
+  /**
+   * Try to rewrite the key as LongType so we can use getLong(), if they key 
can fit with a long.
+   *
+   * If not, returns the original expressions.
+   */
+  private[joins] def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = {
+    assert(keys.nonEmpty)
+    // TODO: support BooleanType, DateType and TimestampType
+    if (keys.exists(!_.dataType.isInstanceOf[IntegralType])
+      || keys.map(_.dataType.defaultSize).sum > 8) {
+      return keys
+    }
+
+    var keyExpr: Expression = if (keys.head.dataType != LongType) {
+      Cast(keys.head, LongType)
+    } else {
+      keys.head
+    }
+    keys.tail.foreach { e =>
+      val bits = e.dataType.defaultSize * 8
+      keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)),
+        BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1)))
+    }
+    keyExpr :: Nil
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/6d056c16/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index b679e3b..c22c106 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -23,11 +23,13 @@ import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext}
 import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession}
+import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, 
Literal, ShiftLeft}
 import org.apache.spark.sql.execution.exchange.EnsureRequirements
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.types.{LongType, ShortType}
 
 /**
  * Test various broadcast join operators.
@@ -156,4 +158,49 @@ class BroadcastJoinSuite extends QueryTest with 
SQLTestUtils {
       cases.foreach(assertBroadcastJoin)
     }
   }
+
+  test("join key rewritten") {
+    val l = Literal(1L)
+    val i = Literal(2)
+    val s = Literal.create(3, ShortType)
+    val ss = Literal("hello")
+
+    assert(HashJoin.rewriteKeyExpr(l :: Nil) === l :: Nil)
+    assert(HashJoin.rewriteKeyExpr(l :: l :: Nil) === l :: l :: Nil)
+    assert(HashJoin.rewriteKeyExpr(l :: i :: Nil) === l :: i :: Nil)
+
+    assert(HashJoin.rewriteKeyExpr(i :: Nil) === Cast(i, LongType) :: Nil)
+    assert(HashJoin.rewriteKeyExpr(i :: l :: Nil) === i :: l :: Nil)
+    assert(HashJoin.rewriteKeyExpr(i :: i :: Nil) ===
+      BitwiseOr(ShiftLeft(Cast(i, LongType), Literal(32)),
+        BitwiseAnd(Cast(i, LongType), Literal((1L << 32) - 1))) :: Nil)
+    assert(HashJoin.rewriteKeyExpr(i :: i :: i :: Nil) === i :: i :: i :: Nil)
+
+    assert(HashJoin.rewriteKeyExpr(s :: Nil) === Cast(s, LongType) :: Nil)
+    assert(HashJoin.rewriteKeyExpr(s :: l :: Nil) === s :: l :: Nil)
+    assert(HashJoin.rewriteKeyExpr(s :: s :: Nil) ===
+      BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)),
+        BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil)
+    assert(HashJoin.rewriteKeyExpr(s :: s :: s :: Nil) ===
+      BitwiseOr(ShiftLeft(
+        BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)),
+          BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))),
+        Literal(16)),
+        BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil)
+    assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: Nil) ===
+      BitwiseOr(ShiftLeft(
+        BitwiseOr(ShiftLeft(
+          BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)),
+            BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))),
+          Literal(16)),
+          BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))),
+        Literal(16)),
+        BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil)
+    assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: s :: Nil) ===
+      s :: s :: s :: s :: s :: Nil)
+
+    assert(HashJoin.rewriteKeyExpr(ss :: Nil) === ss :: Nil)
+    assert(HashJoin.rewriteKeyExpr(l :: ss :: Nil) === l :: ss :: Nil)
+    assert(HashJoin.rewriteKeyExpr(i :: ss :: Nil) === i :: ss :: Nil)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to