This is an automated email from the ASF dual-hosted git repository.

ptoth pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new fbf096ecdecb [SPARK-55411][SQL] SPJ may throw 
ArrayIndexOutOfBoundsException when join keys are less than cluster keys
fbf096ecdecb is described below

commit fbf096ecdecb359c4711b0144dd623ff3b449600
Author: Cheng Pan <[email protected]>
AuthorDate: Tue Feb 10 16:43:43 2026 +0100

    [SPARK-55411][SQL] SPJ may throw ArrayIndexOutOfBoundsException when join 
keys are less than cluster keys
    
    ### What changes were proposed in this pull request?
    
    Fix a `java.lang.ArrayIndexOutOfBoundsException` thrown when
    `spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled=true`,
    by correcting the `expressions` argument passed to `KeyGroupedPartitioning#project`:
    it must be the full partition expressions, not the projected ones.
    
    Also, fix a test code issue: change the result computed by `BucketTransform`
    in `InMemoryBaseTable.scala` to match `BucketFunction` in
    `transformFunctions.scala` (thanks peter-toth for pointing this out!)
    
    ### Why are the changes needed?
    
    It's a bug fix.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Some queries that previously failed with `ArrayIndexOutOfBoundsException` when
    `spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled=true`
    now run normally.
    
    ### How was this patch tested?
    
    A new UT is added; it previously failed with
    `ArrayIndexOutOfBoundsException` and now passes.
    
    ```
    $ build/sbt "sql/testOnly *KeyGroupedPartitioningSuite -- -z SPARK-55411"
    ...
    [info] - bug *** FAILED *** (1 second, 884 milliseconds)
    [info]   java.lang.ArrayIndexOutOfBoundsException: Index 1 out of bounds 
for length 1
    [info]   at 
scala.collection.immutable.ArraySeq$ofRef.apply(ArraySeq.scala:331)
    [info]   at 
org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning$.$anonfun$project$1(partitioning.scala:471)
    [info]   at 
org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning$.$anonfun$project$1$adapted(partitioning.scala:471)
    [info]   at scala.collection.immutable.ArraySeq.map(ArraySeq.scala:75)
    [info]   at scala.collection.immutable.ArraySeq.map(ArraySeq.scala:35)
    [info]   at 
org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning$.project(partitioning.scala:471)
    [info]   at 
org.apache.spark.sql.execution.KeyGroupedPartitionedScan.$anonfun$getOutputKeyGroupedPartitioning$5(KeyGroupedPartitionedScan.scala:58)
    ...
    ```
    
    UTs affected by the change to the `bucket()` calculation logic are updated accordingly.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #54182 from pan3793/spj-subset-joinkey-bug.
    
    Authored-by: Cheng Pan <[email protected]>
    Signed-off-by: Peter Toth <[email protected]>
---
 .../sql/connector/catalog/InMemoryBaseTable.scala  | 27 +++++++++-----
 .../sql/execution/KeyGroupedPartitionedScan.scala  |  8 ++---
 .../connector/KeyGroupedPartitioningSuite.scala    | 42 ++++++++++++++++++++--
 .../spark/sql/connector/MetadataColumnSuite.scala  | 18 +++++-----
 .../catalog/functions/transformFunctions.scala     |  3 +-
 5 files changed, 73 insertions(+), 25 deletions(-)

diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index ebb4eef80f15..407d592f8219 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -236,15 +236,26 @@ abstract class InMemoryBaseTable(
           case (v, t) =>
             throw new IllegalArgumentException(s"Match: unsupported 
argument(s) type - ($v, $t)")
         }
+      // the result should be consistent with BucketFunctions defined at 
transformFunctions.scala
       case BucketTransform(numBuckets, cols, _) =>
-        val valueTypePairs = cols.map(col => extractor(col.fieldNames, 
cleanedSchema, row))
-        var valueHashCode = 0
-        valueTypePairs.foreach( pair =>
-          if ( pair._1 != null) valueHashCode += pair._1.hashCode()
-        )
-        var dataTypeHashCode = 0
-        valueTypePairs.foreach(dataTypeHashCode += _._2.hashCode())
-        ((valueHashCode + 31 * dataTypeHashCode) & Integer.MAX_VALUE) % 
numBuckets
+        val hash: Long = cols.foldLeft(0L) { (acc, col) =>
+          val valueHash = extractor(col.fieldNames, cleanedSchema, row) match {
+            case (value: Byte, _: ByteType) => value.toLong
+            case (value: Short, _: ShortType) => value.toLong
+            case (value: Int, _: IntegerType) => value.toLong
+            case (value: Long, _: LongType) => value
+            case (value: Long, _: TimestampType) => value
+            case (value: Long, _: TimestampNTZType) => value
+            case (value: UTF8String, _: StringType) =>
+              value.hashCode.toLong
+            case (value: Array[Byte], BinaryType) =>
+              util.Arrays.hashCode(value).toLong
+            case (v, t) =>
+              throw new IllegalArgumentException(s"Match: unsupported 
argument(s) type - ($v, $t)")
+          }
+          acc + valueHash
+        }
+        Math.floorMod(hash, numBuckets)
       case NamedTransform("truncate", Seq(ref: NamedReference, length: 
V2Literal[_])) =>
         extractor(ref.fieldNames, cleanedSchema, row) match {
           case (str: UTF8String, StringType) =>
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala
index 10a6aaa2e185..cac4a9bc852f 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala
@@ -34,7 +34,7 @@ trait KeyGroupedPartitionedScan[T] {
   def getOutputKeyGroupedPartitioning(
       basePartitioning: KeyGroupedPartitioning,
       spjParams: StoragePartitionJoinParams): KeyGroupedPartitioning = {
-    val expressions = spjParams.joinKeyPositions match {
+    val projectedExpressions = spjParams.joinKeyPositions match {
       case Some(projectionPositions) =>
         projectionPositions.map(i => basePartitioning.expressions(i))
       case _ => basePartitioning.expressions
@@ -52,16 +52,16 @@ trait KeyGroupedPartitionedScan[T] {
           case Some(projectionPositions) =>
             val internalRowComparableWrapperFactory =
               
InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
-                expressions.map(_.dataType))
+                projectedExpressions.map(_.dataType))
             basePartitioning.partitionValues.map { r =>
-            val projectedRow = KeyGroupedPartitioning.project(expressions,
+            val projectedRow = 
KeyGroupedPartitioning.project(basePartitioning.expressions,
               projectionPositions, r)
             internalRowComparableWrapperFactory(projectedRow)
           }.distinct.map(_.row)
           case _ => basePartitioning.partitionValues
         }
     }
-    basePartitioning.copy(expressions = expressions, numPartitions = 
newPartValues.length,
+    basePartitioning.copy(expressions = projectedExpressions, numPartitions = 
newPartValues.length,
       partitionValues = newPartValues)
   }
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index 7c07d08d80af..8cd55304d71c 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -124,7 +124,7 @@ class KeyGroupedPartitioningSuite extends 
DistributionAndOrderingSuiteBase {
       Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32))))
 
     // Has exactly one partition.
-    val partitionValues = Seq(31).map(v => InternalRow.fromSeq(Seq(v)))
+    val partitionValues = Seq(0).map(v => InternalRow.fromSeq(Seq(v)))
     checkQueryPlan(df, distribution,
       physical.KeyGroupedPartitioning(distribution.clustering, 1, 
partitionValues, partitionValues))
   }
@@ -2798,8 +2798,6 @@ class KeyGroupedPartitioningSuite extends 
DistributionAndOrderingSuiteBase {
   }
 
   test("SPARK-54439: KeyGroupedPartitioning with transform and join key size 
mismatch") {
-    // Do not use `bucket()` in "one side partition" tests as its 
implementation in
-    // `InMemoryBaseTable` conflicts with `BucketFunction`
     val items_partitions = Array(years("arrive_time"))
     createTable(items, itemsColumns, items_partitions)
 
@@ -2841,4 +2839,42 @@ class KeyGroupedPartitioningSuite extends 
DistributionAndOrderingSuiteBase {
     }
     assert(metrics("number of rows read") == "3")
   }
+
+  test("SPARK-55411: Fix ArrayIndexOutOfBoundsException when join keys " +
+    "are less than cluster keys") {
+    withSQLConf(
+      SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+      SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
+      SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+      SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> 
"false",
+      SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> 
"true") {
+
+      val customers_partitions = Array(identity("customer_name"), bucket(4, 
"customer_id"))
+      createTable(customers, customersColumns, customers_partitions)
+      sql(s"INSERT INTO testcat.ns.$customers VALUES " +
+        s"('aaa', 10, 1), ('bbb', 20, 2), ('ccc', 30, 3)")
+
+      createTable(orders, ordersColumns, Array.empty)
+      sql(s"INSERT INTO testcat.ns.$orders VALUES " +
+        s"(100.0, 1), (200.0, 1), (150.0, 2), (250.0, 2), (350.0, 2), (400.50, 
3)")
+
+      val df = sql(
+        s"""${selectWithMergeJoinHint("c", "o")}
+           |customer_name, customer_age, order_amount
+           |FROM testcat.ns.$customers c JOIN testcat.ns.$orders o
+           |ON c.customer_id = o.customer_id ORDER BY c.customer_id, 
order_amount
+           |""".stripMargin)
+
+      val shuffles = collectShuffles(df.queryExecution.executedPlan)
+      assert(shuffles.length == 1)
+
+      checkAnswer(df, Seq(
+        Row("aaa", 10, 100.0),
+        Row("aaa", 10, 200.0),
+        Row("bbb", 20, 150.0),
+        Row("bbb", 20, 250.0),
+        Row("bbb", 20, 350.0),
+        Row("ccc", 30, 400.50)))
+    }
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
index 3bfd57e867c0..fe338175ec88 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
@@ -42,7 +42,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
       val dfQuery = spark.table(tbl).select("id", "data", "index", 
"_partition")
 
       Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), 
Row(3, "c", 0, "1/3")))
+        checkAnswer(query, Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), 
Row(3, "c", 0, "3/3")))
       }
     }
   }
@@ -56,7 +56,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
       val dfQuery = spark.table(tbl).select("index", "data", "_partition")
 
       Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(3, "c", "1/3"), Row(2, "b", "0/2"), Row(1, 
"a", "3/1")))
+        checkAnswer(query, Seq(Row(3, "c", "3/3"), Row(2, "b", "2/2"), Row(1, 
"a", "1/1")))
       }
     }
   }
@@ -125,7 +125,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
 
       checkAnswer(
         dfQuery,
-        Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, 
"1/3"))
+        Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), Row(3, "c", 0, 
"3/3"))
       )
     }
   }
@@ -135,7 +135,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
       prepareTable()
       checkAnswer(
         spark.table(tbl).select("id", "data").select("index", "_partition"),
-        Seq(Row(0, "3/1"), Row(0, "0/2"), Row(0, "1/3"))
+        Seq(Row(0, "1/1"), Row(0, "2/2"), Row(0, "3/3"))
       )
     }
   }
@@ -160,7 +160,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
       val dfQuery = spark.table(tbl).where("id > 1").select("id", "data", 
"index", "_partition")
 
       Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+        checkAnswer(query, Seq(Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3")))
       }
     }
   }
@@ -172,7 +172,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
       val dfQuery = spark.table(tbl).orderBy("id").select("id", "data", 
"index", "_partition")
 
       Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), 
Row(3, "c", 0, "1/3")))
+        checkAnswer(query, Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), 
Row(3, "c", 0, "3/3")))
       }
     }
   }
@@ -186,7 +186,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
         .select("id", "data", "index", "_partition")
 
       Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), 
Row(3, "c", 0, "1/3")))
+        checkAnswer(query, Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), 
Row(3, "c", 0, "3/3")))
       }
     }
   }
@@ -201,7 +201,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
         s"$sbq.id", s"$sbq.data", s"$sbq.index", s"$sbq._partition")
 
       Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), 
Row(3, "c", 0, "1/3")))
+        checkAnswer(query, Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), 
Row(3, "c", 0, "3/3")))
       }
 
       // Metadata columns are propagated through SubqueryAlias even if child 
is not a leaf node.
@@ -394,7 +394,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
       withTable(tbl) {
         sql(s"CREATE TABLE $tbl (id bigint, data char(1)) PARTITIONED BY 
(bucket(4, id), id)")
         sql(s"INSERT INTO $tbl VALUES (1, 'a'), (2, 'b'), (3, 'c')")
-        val expected = Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), 
Row(3, "c", 0, "1/3"))
+        val expected = Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), 
Row(3, "c", 0, "3/3"))
 
         // Unqualified column access
         checkAnswer(sql(s"SELECT id, data, index, _partition FROM $tbl"), 
expected)
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
index b82cc2392e1f..ed2f81d7e8d6 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
@@ -84,6 +84,7 @@ object UnboundBucketFunction extends UnboundFunction {
   override def name(): String = "bucket"
 }
 
+// the result should be consistent with BucketTransform defined at 
InMemoryBaseTable.scala
 object BucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, 
Int] {
   override def inputTypes(): Array[DataType] = Array(IntegerType, LongType)
   override def resultType(): DataType = IntegerType
@@ -91,7 +92,7 @@ object BucketFunction extends ScalarFunction[Int] with 
ReducibleFunction[Int, In
   override def canonicalName(): String = name()
   override def toString: String = name()
   override def produceResult(input: InternalRow): Int = {
-    (input.getLong(1) % input.getInt(0)).toInt
+    Math.floorMod(input.getLong(1), input.getInt(0))
   }
 
   override def reducer(


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to