This is an automated email from the ASF dual-hosted git repository.
ptoth pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new fbf096ecdecb [SPARK-55411][SQL] SPJ may throw ArrayIndexOutOfBoundsException when join keys are less than cluster keys
fbf096ecdecb is described below
commit fbf096ecdecb359c4711b0144dd623ff3b449600
Author: Cheng Pan <[email protected]>
AuthorDate: Tue Feb 10 16:43:43 2026 +0100
[SPARK-55411][SQL] SPJ may throw ArrayIndexOutOfBoundsException when join keys are less than cluster keys
### What changes were proposed in this pull request?
Fix a `java.lang.ArrayIndexOutOfBoundsException` that can occur when
`spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled=true`,
by correcting the expressions passed to `KeyGroupedPartitioning#project`: the full
partition expressions must be passed instead of the already projected ones.
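A minimal, self-contained sketch of the indexing mistake (hypothetical names and types, not the real `KeyGroupedPartitioning.project` signature):
```scala
object ProjectionSketch {
  // `project` looks up expressions(i) for every join-key position i, so it must
  // be given the FULL cluster-key expression list, not the projected one.
  def project(
      expressions: Seq[String],
      positions: Seq[Int],
      partValue: Seq[Any]): Seq[(String, Any)] =
    positions.map(i => expressions(i) -> partValue(i))

  def main(args: Array[String]): Unit = {
    val clusterKeys = Seq("identity(customer_name)", "bucket(4, customer_id)")
    val joinKeyPositions = Seq(1)            // the join key is only customer_id
    val partValue = Seq[Any]("aaa", 1)

    // Correct: index into the full cluster-key expressions.
    println(project(clusterKeys, joinKeyPositions, partValue))

    // Buggy: the projected list has length 1, so position 1 overflows,
    // mirroring the ArrayIndexOutOfBoundsException in the stack trace below.
    val alreadyProjected = joinKeyPositions.map(i => clusterKeys(i))
    try project(alreadyProjected, joinKeyPositions, partValue)
    catch { case e: IndexOutOfBoundsException => println(e) }
  }
}
```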
Also fix a test code issue: the `BucketTransform` calculation in
`InMemoryBaseTable.scala` is changed to match `BucketFunction` defined in
`transformFunctions.scala` (thanks peter-toth for pointing this out!)
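A tiny illustration (assumed values, not taken from the patch) of the shared formula both test helpers now use, and of why `Math.floorMod` replaces plain `%`:
```scala
// Minimal sketch: both test helpers now reduce an integral column to its long
// value and take Math.floorMod, so the in-memory table's partition values match
// what BucketFunction computes at query time.
object BucketConsistencySketch {
  def main(args: Array[String]): Unit = {
    val numBuckets = 4
    val customerId = 3L
    println(Math.floorMod(customerId, numBuckets))   // 3 on both sides

    // Unlike plain `%`, floorMod never yields a negative bucket id:
    println(-7L % numBuckets)                        // -3
    println(Math.floorMod(-7L, numBuckets))          //  1
  }
}
```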
### Why are the changes needed?
It's a bug fix.
### Does this PR introduce _any_ user-facing change?
Some queries that failed when
`spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled=true`
now run normally.
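For example, a spark-shell query shaped like the new UT in this patch used to fail and now runs. This is illustrative only; the catalog/table names and all conf key spellings except the subset-of-partition-keys one above are assumptions, not part of the change:
```scala
spark.conf.set("spark.sql.sources.v2.bucketing.enabled", "true")
spark.conf.set("spark.sql.sources.v2.bucketing.pushPartValues.enabled", "true")
spark.conf.set("spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled", "true")

// customers is partitioned by (identity(customer_name), bucket(4, customer_id));
// the join key (customer_id) is only a subset of those cluster keys.
spark.sql(
  """SELECT /*+ MERGE(c, o) */ customer_name, customer_age, order_amount
    |FROM testcat.ns.customers c JOIN testcat.ns.orders o
    |ON c.customer_id = o.customer_id
    |ORDER BY c.customer_id, order_amount
    |""".stripMargin).show()
```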
### How was this patch tested?
A new UT is added; it previously failed with
`ArrayIndexOutOfBoundsException` and now passes.
```
$ build/sbt "sql/testOnly *KeyGroupedPartitioningSuite -- -z SPARK-55411"
...
[info] - bug *** FAILED *** (1 second, 884 milliseconds)
[info] java.lang.ArrayIndexOutOfBoundsException: Index 1 out of bounds for length 1
[info] at scala.collection.immutable.ArraySeq$ofRef.apply(ArraySeq.scala:331)
[info] at org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning$.$anonfun$project$1(partitioning.scala:471)
[info] at org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning$.$anonfun$project$1$adapted(partitioning.scala:471)
[info] at scala.collection.immutable.ArraySeq.map(ArraySeq.scala:75)
[info] at scala.collection.immutable.ArraySeq.map(ArraySeq.scala:35)
[info] at org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning$.project(partitioning.scala:471)
[info] at org.apache.spark.sql.execution.KeyGroupedPartitionedScan.$anonfun$getOutputKeyGroupedPartitioning$5(KeyGroupedPartitionedScan.scala:58)
...
```
UTs affected by the `bucket()` calculation logic change are updated accordingly.
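For instance, the `MetadataColumnSuite` expectations change because, with `PARTITIONED BY (bucket(4, id), id)`, the bucket component of `_partition` is now `Math.floorMod(id, 4)`. A quick sanity check (not part of the patch) of the new expected values:
```scala
// ids 1, 2, 3 under bucket(4, id) now land in buckets 1, 2, 3.
Seq(1L, 2L, 3L).foreach { id =>
  val bucket = Math.floorMod(id, 4)
  println(s"id=$id -> _partition=$bucket/$id")   // 1/1, 2/2, 3/3
}
```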
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #54182 from pan3793/spj-subset-joinkey-bug.
Authored-by: Cheng Pan <[email protected]>
Signed-off-by: Peter Toth <[email protected]>
---
.../sql/connector/catalog/InMemoryBaseTable.scala | 27 +++++++++-----
.../sql/execution/KeyGroupedPartitionedScan.scala | 8 ++---
.../connector/KeyGroupedPartitioningSuite.scala | 42 ++++++++++++++++++++--
.../spark/sql/connector/MetadataColumnSuite.scala | 18 +++++-----
.../catalog/functions/transformFunctions.scala | 3 +-
5 files changed, 73 insertions(+), 25 deletions(-)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index ebb4eef80f15..407d592f8219 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -236,15 +236,26 @@ abstract class InMemoryBaseTable(
case (v, t) =>
throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
}
+ // the result should be consistent with BucketFunctions defined at transformFunctions.scala
case BucketTransform(numBuckets, cols, _) =>
- val valueTypePairs = cols.map(col => extractor(col.fieldNames, cleanedSchema, row))
- var valueHashCode = 0
- valueTypePairs.foreach( pair =>
- if ( pair._1 != null) valueHashCode += pair._1.hashCode()
- )
- var dataTypeHashCode = 0
- valueTypePairs.foreach(dataTypeHashCode += _._2.hashCode())
- ((valueHashCode + 31 * dataTypeHashCode) & Integer.MAX_VALUE) % numBuckets
+ val hash: Long = cols.foldLeft(0L) { (acc, col) =>
+ val valueHash = extractor(col.fieldNames, cleanedSchema, row) match {
+ case (value: Byte, _: ByteType) => value.toLong
+ case (value: Short, _: ShortType) => value.toLong
+ case (value: Int, _: IntegerType) => value.toLong
+ case (value: Long, _: LongType) => value
+ case (value: Long, _: TimestampType) => value
+ case (value: Long, _: TimestampNTZType) => value
+ case (value: UTF8String, _: StringType) => value.hashCode.toLong
+ case (value: Array[Byte], BinaryType) => util.Arrays.hashCode(value).toLong
+ case (v, t) =>
+ throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
+ }
+ acc + valueHash
+ }
+ Math.floorMod(hash, numBuckets)
case NamedTransform("truncate", Seq(ref: NamedReference, length: V2Literal[_])) =>
extractor(ref.fieldNames, cleanedSchema, row) match {
case (str: UTF8String, StringType) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala
index 10a6aaa2e185..cac4a9bc852f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala
@@ -34,7 +34,7 @@ trait KeyGroupedPartitionedScan[T] {
def getOutputKeyGroupedPartitioning(
basePartitioning: KeyGroupedPartitioning,
spjParams: StoragePartitionJoinParams): KeyGroupedPartitioning = {
- val expressions = spjParams.joinKeyPositions match {
+ val projectedExpressions = spjParams.joinKeyPositions match {
case Some(projectionPositions) =>
projectionPositions.map(i => basePartitioning.expressions(i))
case _ => basePartitioning.expressions
@@ -52,16 +52,16 @@ trait KeyGroupedPartitionedScan[T] {
case Some(projectionPositions) =>
val internalRowComparableWrapperFactory =
InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
- expressions.map(_.dataType))
+ projectedExpressions.map(_.dataType))
basePartitioning.partitionValues.map { r =>
- val projectedRow = KeyGroupedPartitioning.project(expressions,
+ val projectedRow = KeyGroupedPartitioning.project(basePartitioning.expressions,
projectionPositions, r)
internalRowComparableWrapperFactory(projectedRow)
}.distinct.map(_.row)
case _ => basePartitioning.partitionValues
}
}
- basePartitioning.copy(expressions = expressions, numPartitions = newPartValues.length,
+ basePartitioning.copy(expressions = projectedExpressions, numPartitions = newPartValues.length,
partitionValues = newPartValues)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index 7c07d08d80af..8cd55304d71c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -124,7 +124,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32))))
// Has exactly one partition.
- val partitionValues = Seq(31).map(v => InternalRow.fromSeq(Seq(v)))
+ val partitionValues = Seq(0).map(v => InternalRow.fromSeq(Seq(v)))
checkQueryPlan(df, distribution,
physical.KeyGroupedPartitioning(distribution.clustering, 1, partitionValues, partitionValues))
}
@@ -2798,8 +2798,6 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
}
test("SPARK-54439: KeyGroupedPartitioning with transform and join key size
mismatch") {
- // Do not use `bucket()` in "one side partition" tests as its implementation in
- // `InMemoryBaseTable` conflicts with `BucketFunction`
val items_partitions = Array(years("arrive_time"))
createTable(items, itemsColumns, items_partitions)
@@ -2841,4 +2839,42 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
}
assert(metrics("number of rows read") == "3")
}
+
+ test("SPARK-55411: Fix ArrayIndexOutOfBoundsException when join keys " +
+ "are less than cluster keys") {
+ withSQLConf(
+ SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+ SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
+ SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+ SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false",
+ SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {
+
+ val customers_partitions = Array(identity("customer_name"), bucket(4, "customer_id"))
+ createTable(customers, customersColumns, customers_partitions)
+ sql(s"INSERT INTO testcat.ns.$customers VALUES " +
+ s"('aaa', 10, 1), ('bbb', 20, 2), ('ccc', 30, 3)")
+
+ createTable(orders, ordersColumns, Array.empty)
+ sql(s"INSERT INTO testcat.ns.$orders VALUES " +
+ s"(100.0, 1), (200.0, 1), (150.0, 2), (250.0, 2), (350.0, 2), (400.50,
3)")
+
+ val df = sql(
+ s"""${selectWithMergeJoinHint("c", "o")}
+ |customer_name, customer_age, order_amount
+ |FROM testcat.ns.$customers c JOIN testcat.ns.$orders o
+ |ON c.customer_id = o.customer_id ORDER BY c.customer_id, order_amount
+ |""".stripMargin)
+
+ val shuffles = collectShuffles(df.queryExecution.executedPlan)
+ assert(shuffles.length == 1)
+
+ checkAnswer(df, Seq(
+ Row("aaa", 10, 100.0),
+ Row("aaa", 10, 200.0),
+ Row("bbb", 20, 150.0),
+ Row("bbb", 20, 250.0),
+ Row("bbb", 20, 350.0),
+ Row("ccc", 30, 400.50)))
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
index 3bfd57e867c0..fe338175ec88 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
@@ -42,7 +42,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
val dfQuery = spark.table(tbl).select("id", "data", "index", "_partition")
Seq(sqlQuery, dfQuery).foreach { query =>
- checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+ checkAnswer(query, Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3")))
}
}
}
@@ -56,7 +56,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
val dfQuery = spark.table(tbl).select("index", "data", "_partition")
Seq(sqlQuery, dfQuery).foreach { query =>
- checkAnswer(query, Seq(Row(3, "c", "1/3"), Row(2, "b", "0/2"), Row(1, "a", "3/1")))
+ checkAnswer(query, Seq(Row(3, "c", "3/3"), Row(2, "b", "2/2"), Row(1, "a", "1/1")))
}
}
}
@@ -125,7 +125,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
checkAnswer(
dfQuery,
- Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))
+ Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3"))
)
}
}
@@ -135,7 +135,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
prepareTable()
checkAnswer(
spark.table(tbl).select("id", "data").select("index", "_partition"),
- Seq(Row(0, "3/1"), Row(0, "0/2"), Row(0, "1/3"))
+ Seq(Row(0, "1/1"), Row(0, "2/2"), Row(0, "3/3"))
)
}
}
@@ -160,7 +160,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
val dfQuery = spark.table(tbl).where("id > 1").select("id", "data", "index", "_partition")
Seq(sqlQuery, dfQuery).foreach { query =>
- checkAnswer(query, Seq(Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+ checkAnswer(query, Seq(Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3")))
}
}
}
@@ -172,7 +172,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
val dfQuery = spark.table(tbl).orderBy("id").select("id", "data", "index", "_partition")
Seq(sqlQuery, dfQuery).foreach { query =>
- checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+ checkAnswer(query, Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3")))
}
}
}
@@ -186,7 +186,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
.select("id", "data", "index", "_partition")
Seq(sqlQuery, dfQuery).foreach { query =>
- checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+ checkAnswer(query, Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3")))
}
}
}
@@ -201,7 +201,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
s"$sbq.id", s"$sbq.data", s"$sbq.index", s"$sbq._partition")
Seq(sqlQuery, dfQuery).foreach { query =>
- checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+ checkAnswer(query, Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3")))
}
// Metadata columns are propagated through SubqueryAlias even if child is not a leaf node.
@@ -394,7 +394,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
withTable(tbl) {
sql(s"CREATE TABLE $tbl (id bigint, data char(1)) PARTITIONED BY
(bucket(4, id), id)")
sql(s"INSERT INTO $tbl VALUES (1, 'a'), (2, 'b'), (3, 'c')")
- val expected = Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))
+ val expected = Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3"))
// Unqualified column access
checkAnswer(sql(s"SELECT id, data, index, _partition FROM $tbl"),
expected)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
index b82cc2392e1f..ed2f81d7e8d6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
@@ -84,6 +84,7 @@ object UnboundBucketFunction extends UnboundFunction {
override def name(): String = "bucket"
}
+// the result should be consistent with BucketTransform defined at InMemoryBaseTable.scala
object BucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, Int] {
override def inputTypes(): Array[DataType] = Array(IntegerType, LongType)
override def resultType(): DataType = IntegerType
@@ -91,7 +92,7 @@ object BucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, In
override def canonicalName(): String = name()
override def toString: String = name()
override def produceResult(input: InternalRow): Int = {
- (input.getLong(1) % input.getInt(0)).toInt
+ Math.floorMod(input.getLong(1), input.getInt(0))
}
override def reducer(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]