This is an automated email from the ASF dual-hosted git repository.
gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 29eca8c [SPARK-38325][SQL] ANSI mode: avoid potential runtime error
in HashJoin.extractKeyExprAt()
29eca8c is described below
commit 29eca8c87f4e8c19c0380f7c30668fd88edee573
Author: Gengliang Wang <[email protected]>
AuthorDate: Fri Feb 25 17:11:15 2022 +0800
[SPARK-38325][SQL] ANSI mode: avoid potential runtime error in
HashJoin.extractKeyExprAt()
### What changes were proposed in this pull request?
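SubqueryBroadcastExec retrieves the partition key from the broadcast
results based on the type of HashedRelation returned. If the key is packed
inside a Long, it is extracted through bitwise operations and then cast to
Byte/Short/Int if necessary.
Under ANSI mode this cast can fail at runtime. When several keys are packed
into one Long, the bits extracted for a negative key form a non-negative
Long that lies outside the target type's range, so the overflow-checked
cast throws. This PR fixes it by always performing the cast with ANSI mode
disabled, which truncates the value back to the original key.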
### Why are the changes needed?
Bug fix
### Does this PR introduce _any_ user-facing change?
Yes. It avoids a potential runtime error in dynamic pruning under ANSI mode.
### How was this patch tested?
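Unit test: the key rewrite/extraction test in HashedRelationSuite now runs
with ANSI mode both disabled and enabled.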
Closes #35659 from gengliangwang/fixHashJoin.
Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
---
.../spark/sql/execution/joins/HashJoin.scala | 27 +++++++++++++++++-----
.../sql/execution/joins/HashedRelationSuite.scala | 22 +++++++++++-------
2 files changed, 35 insertions(+), 14 deletions(-)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 0e8bb84..4595ea0 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -705,6 +705,13 @@ trait HashJoin extends JoinCodegenSupport {
}
object HashJoin extends CastSupport with SQLConfHelper {
+
+ private def canRewriteAsLongType(keys: Seq[Expression]): Boolean = {
+ // TODO: support BooleanType, DateType and TimestampType
+ keys.forall(_.dataType.isInstanceOf[IntegralType]) &&
+ keys.map(_.dataType.defaultSize).sum <= 8
+ }
+
/**
* Try to rewrite the key as LongType so we can use getLong(), if they key
can fit with a long.
*
@@ -712,9 +719,7 @@ object HashJoin extends CastSupport with SQLConfHelper {
*/
def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = {
assert(keys.nonEmpty)
- // TODO: support BooleanType, DateType and TimestampType
- if (keys.exists(!_.dataType.isInstanceOf[IntegralType])
- || keys.map(_.dataType.defaultSize).sum > 8) {
+ if (!canRewriteAsLongType(keys)) {
return keys
}
@@ -736,18 +741,28 @@ object HashJoin extends CastSupport with SQLConfHelper {
* determine the number of bits to shift
*/
def extractKeyExprAt(keys: Seq[Expression], index: Int): Expression = {
+ assert(canRewriteAsLongType(keys))
// jump over keys that have a higher index value than the required key
if (keys.size == 1) {
assert(index == 0)
- cast(BoundReference(0, LongType, nullable = false), keys(index).dataType)
+ Cast(
+ child = BoundReference(0, LongType, nullable = false),
+ dataType = keys(index).dataType,
+ timeZoneId = Option(conf.sessionLocalTimeZone),
+ ansiEnabled = false)
} else {
val shiftedBits =
keys.slice(index + 1, keys.size).map(_.dataType.defaultSize * 8).sum
val mask = (1L << (keys(index).dataType.defaultSize * 8)) - 1
// build the schema for unpacking the required key
- cast(BitwiseAnd(
+ val castChild = BitwiseAnd(
ShiftRightUnsigned(BoundReference(0, LongType, nullable = false),
Literal(shiftedBits)),
- Literal(mask)), keys(index).dataType)
+ Literal(mask))
+ Cast(
+ child = castChild,
+ dataType = keys(index).dataType,
+ timeZoneId = Option(conf.sessionLocalTimeZone),
+ ansiEnabled = false)
}
}
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index b8ffc47..d5b7ed6 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.memory.{TaskMemoryManager,
UnifiedMemoryManager}
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.map.BytesToBytesMap
@@ -610,14 +611,19 @@ class HashedRelationSuite extends SharedSparkSession {
val keys = Seq(BoundReference(0, ByteType, false),
BoundReference(1, IntegerType, false),
BoundReference(2, ShortType, false))
- val packed = HashJoin.rewriteKeyExpr(keys)
- val unsafeProj = UnsafeProjection.create(packed)
- val packedKeys = unsafeProj(row)
-
- Seq((0, ByteType), (1, IntegerType), (2, ShortType)).foreach { case (i,
dt) =>
- val key = HashJoin.extractKeyExprAt(keys, i)
- val proj = UnsafeProjection.create(key)
- assert(proj(packedKeys).get(0, dt) == -i - 1)
+ // Rewriting and extracting key expressions should not throw when ANSI mode is on.
+ Seq("false", "true").foreach { ansiEnabled =>
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled) {
+ val packed = HashJoin.rewriteKeyExpr(keys)
+ val unsafeProj = UnsafeProjection.create(packed)
+ val packedKeys = unsafeProj(row)
+
+ Seq((0, ByteType), (1, IntegerType), (2, ShortType)).foreach { case
(i, dt) =>
+ val key = HashJoin.extractKeyExprAt(keys, i)
+ val proj = UnsafeProjection.create(key)
+ assert(proj(packedKeys).get(0, dt) == -i - 1)
+ }
+ }
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]