This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 29eca8c  [SPARK-38325][SQL] ANSI mode: avoid potential runtime error in HashJoin.extractKeyExprAt()
29eca8c is described below

commit 29eca8c87f4e8c19c0380f7c30668fd88edee573
Author: Gengliang Wang <gengli...@apache.org>
AuthorDate: Fri Feb 25 17:11:15 2022 +0800

    [SPARK-38325][SQL] ANSI mode: avoid potential runtime error in HashJoin.extractKeyExprAt()
    
    ### What changes were proposed in this pull request?
    
    SubqueryBroadcastExec retrieves the partition key from the broadcast results based on the type of HashedRelation returned. If the key is packed inside a Long, it is extracted with bitwise operations and cast to Byte/Short/Int if necessary.
    
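    Under ANSI mode, this cast can fail at runtime. After shifting and masking, a packed key is read back as a non-negative (unsigned) value, so a negative Byte key such as -1 comes back as 255, which overflows ByteType in an ANSI cast. This PR fixes the problem by always performing these internal casts with ANSI mode disabled.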
    
    ### Why are the changes needed?
    
    Bug fix
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. It avoids a potential runtime error in dynamic partition pruning when ANSI mode is enabled.
    
    ### How was this patch tested?
    
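    Unit test: HashedRelationSuite now runs the key rewrite/extraction check with ANSI mode both disabled and enabled.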
    
    Closes #35659 from gengliangwang/fixHashJoin.
    
    Authored-by: Gengliang Wang <gengli...@apache.org>
    Signed-off-by: Gengliang Wang <gengli...@apache.org>
---
 .../spark/sql/execution/joins/HashJoin.scala       | 27 +++++++++++++++++-----
 .../sql/execution/joins/HashedRelationSuite.scala  | 22 +++++++++++-------
 2 files changed, 35 insertions(+), 14 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 0e8bb84..4595ea0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -705,6 +705,13 @@ trait HashJoin extends JoinCodegenSupport {
 }
 
 object HashJoin extends CastSupport with SQLConfHelper {
+
+  private def canRewriteAsLongType(keys: Seq[Expression]): Boolean = {
+    // TODO: support BooleanType, DateType and TimestampType
+    keys.forall(_.dataType.isInstanceOf[IntegralType]) &&
+      keys.map(_.dataType.defaultSize).sum <= 8
+  }
+
   /**
   * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long.
    *
@@ -712,9 +719,7 @@ object HashJoin extends CastSupport with SQLConfHelper {
    */
   def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = {
     assert(keys.nonEmpty)
-    // TODO: support BooleanType, DateType and TimestampType
-    if (keys.exists(!_.dataType.isInstanceOf[IntegralType])
-      || keys.map(_.dataType.defaultSize).sum > 8) {
+    if (!canRewriteAsLongType(keys)) {
       return keys
     }
 
@@ -736,18 +741,28 @@ object HashJoin extends CastSupport with SQLConfHelper {
    * determine the number of bits to shift
    */
   def extractKeyExprAt(keys: Seq[Expression], index: Int): Expression = {
+    assert(canRewriteAsLongType(keys))
     // jump over keys that have a higher index value than the required key
     if (keys.size == 1) {
       assert(index == 0)
-      cast(BoundReference(0, LongType, nullable = false), keys(index).dataType)
+      Cast(
+        child = BoundReference(0, LongType, nullable = false),
+        dataType = keys(index).dataType,
+        timeZoneId = Option(conf.sessionLocalTimeZone),
+        ansiEnabled = false)
     } else {
       val shiftedBits =
         keys.slice(index + 1, keys.size).map(_.dataType.defaultSize * 8).sum
       val mask = (1L << (keys(index).dataType.defaultSize * 8)) - 1
       // build the schema for unpacking the required key
-      cast(BitwiseAnd(
+      val castChild = BitwiseAnd(
         ShiftRightUnsigned(BoundReference(0, LongType, nullable = false), Literal(shiftedBits)),
-        Literal(mask)), keys(index).dataType)
+        Literal(mask))
+      Cast(
+        child = castChild,
+        dataType = keys(index).dataType,
+        timeZoneId = Option(conf.sessionLocalTimeZone),
+        ansiEnabled = false)
     }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index b8ffc47..d5b7ed6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.memory.{TaskMemoryManager, UnifiedMemoryManager}
 import org.apache.spark.serializer.KryoSerializer
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.map.BytesToBytesMap
@@ -610,14 +611,19 @@ class HashedRelationSuite extends SharedSparkSession {
     val keys = Seq(BoundReference(0, ByteType, false),
       BoundReference(1, IntegerType, false),
       BoundReference(2, ShortType, false))
-    val packed = HashJoin.rewriteKeyExpr(keys)
-    val unsafeProj = UnsafeProjection.create(packed)
-    val packedKeys = unsafeProj(row)
-
-    Seq((0, ByteType), (1, IntegerType), (2, ShortType)).foreach { case (i, dt) =>
-      val key = HashJoin.extractKeyExprAt(keys, i)
-      val proj = UnsafeProjection.create(key)
-      assert(proj(packedKeys).get(0, dt) == -i - 1)
+    // Rewrite and exacting key expressions should not cause exception when ANSI mode is on.
+    Seq("false", "true").foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled) {
+        val packed = HashJoin.rewriteKeyExpr(keys)
+        val unsafeProj = UnsafeProjection.create(packed)
+        val packedKeys = unsafeProj(row)
+
+        Seq((0, ByteType), (1, IntegerType), (2, ShortType)).foreach { case (i, dt) =>
+          val key = HashJoin.extractKeyExprAt(keys, i)
+          val proj = UnsafeProjection.create(key)
+          assert(proj(packedKeys).get(0, dt) == -i - 1)
+        }
+      }
     }
   }
 

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to