This is an automated email from the ASF dual-hosted git repository.

ptoth pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 1a3e9b777893 [SPARK-55411][SQL][4.0] SPJ may throw 
ArrayIndexOutOfBoundsException when join keys are less than cluster keys
1a3e9b777893 is described below

commit 1a3e9b77789323426524e454abf717ab89057dca
Author: Cheng Pan <[email protected]>
AuthorDate: Wed Feb 11 14:55:44 2026 +0100

    [SPARK-55411][SQL][4.0] SPJ may throw ArrayIndexOutOfBoundsException when 
join keys are less than cluster keys
    
    Backport https://github.com/apache/spark/issues/54182 to branch-4.0
    
    ### What changes were proposed in this pull request?
    
    Fix a `java.lang.ArrayIndexOutOfBoundsException` when
    `spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled=true`
    by correcting the `expressions` argument passed to `KeyGroupedPartitioning#project`:
    it must be the full partition expressions, not the projected ones.
    
    Also, fix a test-code issue: change how `BucketTransform` in
    `InMemoryBaseTable.scala` computes its result so that it matches
    `BucketFunction` in `transformFunctions.scala` (thanks to peter-toth for
    pointing this out!).
    
    ### Why are the changes needed?
    
    It's a bug fix.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Some queries that failed when 
`spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled=true`
 now run normally.
    
    ### How was this patch tested?
    
    A new unit test is added; it previously failed with
    `ArrayIndexOutOfBoundsException` and now passes.
    
    ```
    $ build/sbt "sql/testOnly *KeyGroupedPartitioningSuite -- -z SPARK-55411"
    ...
    [info] - bug *** FAILED *** (1 second, 884 milliseconds)
    [info]   java.lang.ArrayIndexOutOfBoundsException: Index 1 out of bounds 
for length 1
    [info]   at 
scala.collection.immutable.ArraySeq$ofRef.apply(ArraySeq.scala:331)
    [info]   at 
org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning$.$anonfun$project$1(partitioning.scala:471)
    [info]   at 
org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning$.$anonfun$project$1$adapted(partitioning.scala:471)
    [info]   at scala.collection.immutable.ArraySeq.map(ArraySeq.scala:75)
    [info]   at scala.collection.immutable.ArraySeq.map(ArraySeq.scala:35)
    [info]   at 
org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning$.project(partitioning.scala:471)
    [info]   at 
org.apache.spark.sql.execution.KeyGroupedPartitionedScan.$anonfun$getOutputKeyGroupedPartitioning$5(KeyGroupedPartitionedScan.scala:58)
    ...
    ```
    
    Unit tests affected by the change to the `bucket()` calculation logic are updated accordingly.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #54260 from pan3793/SPARK-55411-4.0.
    
    Authored-by: Cheng Pan <[email protected]>
    Signed-off-by: Peter Toth <[email protected]>
---
 .../sql/connector/catalog/InMemoryBaseTable.scala  | 27 +++++++++-----
 .../execution/datasources/v2/BatchScanExec.scala   |  8 ++---
 .../connector/KeyGroupedPartitioningSuite.scala    | 42 ++++++++++++++++++++--
 .../spark/sql/connector/MetadataColumnSuite.scala  | 16 ++++-----
 .../catalog/functions/transformFunctions.scala     |  3 +-
 5 files changed, 72 insertions(+), 24 deletions(-)

diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index 3ac8c3794b8a..96fccd918855 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -194,15 +194,26 @@ abstract class InMemoryBaseTable(
           case (v, t) =>
             throw new IllegalArgumentException(s"Match: unsupported 
argument(s) type - ($v, $t)")
         }
+      // the result should be consistent with BucketFunctions defined at 
transformFunctions.scala
       case BucketTransform(numBuckets, cols, _) =>
-        val valueTypePairs = cols.map(col => extractor(col.fieldNames, 
cleanedSchema, row))
-        var valueHashCode = 0
-        valueTypePairs.foreach( pair =>
-          if ( pair._1 != null) valueHashCode += pair._1.hashCode()
-        )
-        var dataTypeHashCode = 0
-        valueTypePairs.foreach(dataTypeHashCode += _._2.hashCode())
-        ((valueHashCode + 31 * dataTypeHashCode) & Integer.MAX_VALUE) % 
numBuckets
+        val hash: Long = cols.foldLeft(0L) { (acc, col) =>
+          val valueHash = extractor(col.fieldNames, cleanedSchema, row) match {
+            case (value: Byte, _: ByteType) => value.toLong
+            case (value: Short, _: ShortType) => value.toLong
+            case (value: Int, _: IntegerType) => value.toLong
+            case (value: Long, _: LongType) => value
+            case (value: Long, _: TimestampType) => value
+            case (value: Long, _: TimestampNTZType) => value
+            case (value: UTF8String, _: StringType) =>
+              value.hashCode.toLong
+            case (value: Array[Byte], BinaryType) =>
+              util.Arrays.hashCode(value).toLong
+            case (v, t) =>
+              throw new IllegalArgumentException(s"Match: unsupported 
argument(s) type - ($v, $t)")
+          }
+          acc + valueHash
+        }
+        Math.floorMod(hash, numBuckets)
       case NamedTransform("truncate", Seq(ref: NamedReference, length: 
Literal[_])) =>
         extractor(ref.fieldNames, cleanedSchema, row) match {
           case (str: UTF8String, StringType) =>
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
index 6a502a44fad5..82f28bdfbd49 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
@@ -119,7 +119,7 @@ case class BatchScanExec(
   override def outputPartitioning: Partitioning = {
     super.outputPartitioning match {
       case k: KeyGroupedPartitioning =>
-        val expressions = spjParams.joinKeyPositions match {
+        val projectedExpressions = spjParams.joinKeyPositions match {
           case Some(projectionPositions) => projectionPositions.map(i => 
k.expressions(i))
           case _ => k.expressions
         }
@@ -134,14 +134,14 @@ case class BatchScanExec(
           case None =>
             spjParams.joinKeyPositions match {
               case Some(projectionPositions) => k.partitionValues.map{r =>
-                val projectedRow = KeyGroupedPartitioning.project(expressions,
+                val projectedRow = 
KeyGroupedPartitioning.project(k.expressions,
                   projectionPositions, r)
-                InternalRowComparableWrapper(projectedRow, expressions)
+                InternalRowComparableWrapper(projectedRow, 
projectedExpressions)
               }.distinct.map(_.row)
               case _ => k.partitionValues
             }
         }
-        k.copy(expressions = expressions, numPartitions = newPartValues.length,
+        k.copy(expressions = projectedExpressions, numPartitions = 
newPartValues.length,
           partitionValues = newPartValues)
       case p => p
     }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index 44882f294491..3a2eb0673039 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -123,7 +123,7 @@ class KeyGroupedPartitioningSuite extends 
DistributionAndOrderingSuiteBase {
       Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32))))
 
     // Has exactly one partition.
-    val partitionValues = Seq(31).map(v => InternalRow.fromSeq(Seq(v)))
+    val partitionValues = Seq(0).map(v => InternalRow.fromSeq(Seq(v)))
     checkQueryPlan(df, distribution,
       physical.KeyGroupedPartitioning(distribution.clustering, 1, 
partitionValues, partitionValues))
   }
@@ -2653,8 +2653,6 @@ class KeyGroupedPartitioningSuite extends 
DistributionAndOrderingSuiteBase {
   }
 
   test("SPARK-54439: KeyGroupedPartitioning with transform and join key size 
mismatch") {
-    // Do not use `bucket()` in "one side partition" tests as its 
implementation in
-    // `InMemoryBaseTable` conflicts with `BucketFunction`
     val items_partitions = Array(years("arrive_time"))
     createTable(items, itemsColumns, items_partitions)
 
@@ -2678,4 +2676,42 @@ class KeyGroupedPartitioningSuite extends 
DistributionAndOrderingSuiteBase {
       checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0)))
     }
   }
+
+  test("SPARK-55411: Fix ArrayIndexOutOfBoundsException when join keys " +
+    "are less than cluster keys") {
+    withSQLConf(
+      SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+      SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
+      SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+      SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> 
"false",
+      SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> 
"true") {
+
+      val customers_partitions = Array(identity("customer_name"), bucket(4, 
"customer_id"))
+      createTable(customers, customersColumns, customers_partitions)
+      sql(s"INSERT INTO testcat.ns.$customers VALUES " +
+        s"('aaa', 10, 1), ('bbb', 20, 2), ('ccc', 30, 3)")
+
+      createTable(orders, ordersColumns, Array.empty)
+      sql(s"INSERT INTO testcat.ns.$orders VALUES " +
+        s"(100.0, 1), (200.0, 1), (150.0, 2), (250.0, 2), (350.0, 2), (400.50, 
3)")
+
+      val df = sql(
+        s"""${selectWithMergeJoinHint("c", "o")}
+           |customer_name, customer_age, order_amount
+           |FROM testcat.ns.$customers c JOIN testcat.ns.$orders o
+           |ON c.customer_id = o.customer_id ORDER BY c.customer_id, 
order_amount
+           |""".stripMargin)
+
+      val shuffles = collectShuffles(df.queryExecution.executedPlan)
+      assert(shuffles.length == 1)
+
+      checkAnswer(df, Seq(
+        Row("aaa", 10, 100.0),
+        Row("aaa", 10, 200.0),
+        Row("bbb", 20, 150.0),
+        Row("bbb", 20, 250.0),
+        Row("bbb", 20, 350.0),
+        Row("ccc", 30, 400.50)))
+    }
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
index 741e30a739f5..7580d524e7ff 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
@@ -41,7 +41,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
       val dfQuery = spark.table(tbl).select("id", "data", "index", 
"_partition")
 
       Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), 
Row(3, "c", 0, "1/3")))
+        checkAnswer(query, Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), 
Row(3, "c", 0, "3/3")))
       }
     }
   }
@@ -55,7 +55,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
       val dfQuery = spark.table(tbl).select("index", "data", "_partition")
 
       Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(3, "c", "1/3"), Row(2, "b", "0/2"), Row(1, 
"a", "3/1")))
+        checkAnswer(query, Seq(Row(3, "c", "3/3"), Row(2, "b", "2/2"), Row(1, 
"a", "1/1")))
       }
     }
   }
@@ -124,7 +124,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
 
       checkAnswer(
         dfQuery,
-        Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, 
"1/3"))
+        Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), Row(3, "c", 0, 
"3/3"))
       )
     }
   }
@@ -134,7 +134,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
       prepareTable()
       checkAnswer(
         spark.table(tbl).select("id", "data").select("index", "_partition"),
-        Seq(Row(0, "3/1"), Row(0, "0/2"), Row(0, "1/3"))
+        Seq(Row(0, "1/1"), Row(0, "2/2"), Row(0, "3/3"))
       )
     }
   }
@@ -159,7 +159,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
       val dfQuery = spark.table(tbl).where("id > 1").select("id", "data", 
"index", "_partition")
 
       Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+        checkAnswer(query, Seq(Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3")))
       }
     }
   }
@@ -171,7 +171,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
       val dfQuery = spark.table(tbl).orderBy("id").select("id", "data", 
"index", "_partition")
 
       Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), 
Row(3, "c", 0, "1/3")))
+        checkAnswer(query, Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), 
Row(3, "c", 0, "3/3")))
       }
     }
   }
@@ -185,7 +185,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
         .select("id", "data", "index", "_partition")
 
       Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), 
Row(3, "c", 0, "1/3")))
+        checkAnswer(query, Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), 
Row(3, "c", 0, "3/3")))
       }
     }
   }
@@ -200,7 +200,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
         s"$sbq.id", s"$sbq.data", s"$sbq.index", s"$sbq._partition")
 
       Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), 
Row(3, "c", 0, "1/3")))
+        checkAnswer(query, Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), 
Row(3, "c", 0, "3/3")))
       }
 
       assertThrows[AnalysisException] {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
index b82cc2392e1f..ed2f81d7e8d6 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
@@ -84,6 +84,7 @@ object UnboundBucketFunction extends UnboundFunction {
   override def name(): String = "bucket"
 }
 
+// the result should be consistent with BucketTransform defined at 
InMemoryBaseTable.scala
 object BucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, 
Int] {
   override def inputTypes(): Array[DataType] = Array(IntegerType, LongType)
   override def resultType(): DataType = IntegerType
@@ -91,7 +92,7 @@ object BucketFunction extends ScalarFunction[Int] with 
ReducibleFunction[Int, In
   override def canonicalName(): String = name()
   override def toString: String = name()
   override def produceResult(input: InternalRow): Int = {
-    (input.getLong(1) % input.getInt(0)).toInt
+    Math.floorMod(input.getLong(1), input.getInt(0))
   }
 
   override def reducer(


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to