This is an automated email from the ASF dual-hosted git repository.

chengchengjin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new d636fa77c4 [GLUTEN-11088][VL] Fix the Spark4.0 storage partition join (#11184)
d636fa77c4 is described below

commit d636fa77c49e991eb02159a0c25431eb499c6da2
Author: Jin Chengcheng <[email protected]>
AuthorDate: Thu Nov 27 11:10:43 2025 +0000

    [GLUTEN-11088][VL] Fix the Spark4.0 storage partition join (#11184)
---
 .../ColumnarShuffleExchangeExecBase.scala          |  23 +-
 .../gluten/utils/velox/VeloxTestSettings.scala     |  24 +-
 .../GlutenKeyGroupedPartitioningSuite.scala        | 866 ++++++++++++++++++++-
 .../gluten/sql/shims/spark40/Spark40Shims.scala    |  26 +-
 4 files changed, 905 insertions(+), 34 deletions(-)

diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExecBase.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExecBase.scala
index 44f37e4ffb..17d1ec4038 100644
--- a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExecBase.scala
+++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExecBase.scala
@@ -27,7 +27,7 @@ import org.apache.spark.serializer.Serializer
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.Statistics
-import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.plans.physical.{SinglePartition, _}
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.execution.exchange._
 import org.apache.spark.sql.execution.metric.SQLShuffleWriteMetricsReporter
@@ -93,14 +93,21 @@ abstract class ColumnarShuffleExchangeExecBase(
   var cachedShuffleRDD: ShuffledColumnarBatchRDD = _
 
   override protected def doValidateInternal(): ValidationResult = {
-    BackendsApiManager.getValidatorApiInstance
+    val validation = BackendsApiManager.getValidatorApiInstance
       .doColumnarShuffleExchangeExecValidate(output, outputPartitioning, child)
-      .map {
-        reason =>
-          ValidationResult.failed(
-            s"Found schema check failure for schema ${child.schema} due to: $reason")
-      }
-      .getOrElse(ValidationResult.succeeded)
+    if (validation.nonEmpty) {
+      return ValidationResult.failed(
+        s"Found schema check failure for schema ${child.schema} due to: ${validation.get}")
+    }
+    outputPartitioning match {
+      case _: HashPartitioning => ValidationResult.succeeded
+      case _: RangePartitioning => ValidationResult.succeeded
+      case SinglePartition => ValidationResult.succeeded
+      case _: RoundRobinPartitioning => ValidationResult.succeeded
+      case _ =>
+        ValidationResult.failed(
+          s"Unsupported partitioning ${outputPartitioning.getClass.getSimpleName}")
+    }
   }
 
   override def numMappers: Int = inputColumnarRDD.getNumPartitions
diff --git a/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
index cd505ffa21..07437631f9 100644
--- a/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
+++ b/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
@@ -62,16 +62,20 @@ class VeloxTestSettings extends BackendTestSettings {
   enableSuite[GlutenFileDataSourceV2FallBackSuite]
     // Rewritten
     .exclude("Fallback Parquet V2 to V1")
-  // TODO: fix in Spark-4.0
-  // enableSuite[GlutenKeyGroupedPartitioningSuite]
-  //   // NEW SUITE: disable as they check vanilla spark plan
-  //   .exclude("partitioned join: number of buckets mismatch should trigger shuffle")
-  //   .exclude("partitioned join: only one side reports partitioning")
-  //   .exclude("partitioned join: join with two partition keys and different # of partition keys")
-  //   // disable due to check for SMJ node
-  //   .excludeByPrefix("SPARK-41413: partitioned join:")
-  //   .excludeByPrefix("SPARK-42038: partially clustered:")
-  //   .exclude("SPARK-44641: duplicated records when SPJ is not triggered")
+  enableSuite[GlutenKeyGroupedPartitioningSuite]
+    // NEW SUITE: disable as they check vanilla spark plan
+    .exclude("partitioned join: number of buckets mismatch should trigger shuffle")
+    .exclude("partitioned join: only one side reports partitioning")
+    .exclude("partitioned join: join with two partition keys and different # of partition keys")
+    .excludeByPrefix("SPARK-47094")
+    .excludeByPrefix("SPARK-48655")
+    .excludeByPrefix("SPARK-48012")
+    .excludeByPrefix("SPARK-44647")
+    .excludeByPrefix("SPARK-41471")
+    // disable due to check for SMJ node
+    .excludeByPrefix("SPARK-41413: partitioned join:")
+    .excludeByPrefix("SPARK-42038: partially clustered:")
+    .exclude("SPARK-44641: duplicated records when SPJ is not triggered")
   enableSuite[GlutenLocalScanSuite]
   enableSuite[GlutenMetadataColumnSuite]
   enableSuite[GlutenSupportsCatalogOptionsSuite]
diff --git a/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/connector/GlutenKeyGroupedPartitioningSuite.scala b/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/connector/GlutenKeyGroupedPartitioningSuite.scala
index ef87b50400..00d370b39f 100644
--- a/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/connector/GlutenKeyGroupedPartitioningSuite.scala
+++ b/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/connector/GlutenKeyGroupedPartitioningSuite.scala
@@ -20,13 +20,14 @@ import org.apache.gluten.config.GlutenConfig
 import org.apache.gluten.execution.SortMergeJoinExecTransformer
 
 import org.apache.spark.SparkConf
-import org.apache.spark.sql.{GlutenSQLTestsBaseTrait, Row}
+import org.apache.spark.sql.{DataFrame, GlutenSQLTestsBaseTrait, Row}
 import org.apache.spark.sql.connector.catalog.{Column, Identifier, InMemoryTableCatalog}
 import org.apache.spark.sql.connector.distributions.Distributions
-import org.apache.spark.sql.connector.expressions.Expressions.{bucket, days, identity}
+import org.apache.spark.sql.connector.expressions.Expressions.{bucket, days, identity, years}
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.execution.{ColumnarShuffleExchangeExec, SparkPlan}
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
+import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -47,6 +48,17 @@ class GlutenKeyGroupedPartitioningSuite
   private val emptyProps: java.util.Map[String, String] = {
     Collections.emptyMap[String, String]
   }
+
+  private val columns: Array[Column] = Array(
+    Column.create("id", IntegerType),
+    Column.create("data", StringType),
+    Column.create("ts", TimestampType))
+
+  private val columns2: Array[Column] = Array(
+    Column.create("store_id", IntegerType),
+    Column.create("dept_id", IntegerType),
+    Column.create("data", StringType))
+
   private def createTable(
       table: String,
       columns: Array[Column],
@@ -72,10 +84,46 @@ class GlutenKeyGroupedPartitioningSuite
       case s: SortMergeJoinExec => s
     }.flatMap(smj => collect(smj) { case s: ColumnarShuffleExchangeExec => s })
   }
+
+  private def collectShuffles(plan: SparkPlan): Seq[ShuffleExchangeLike] = {
+    // here we skip collecting shuffle operators that are not associated with SMJ
+    collect(plan) {
+      case s: SortMergeJoinExec => s
+      case s: SortMergeJoinExecTransformer => s
+    }.flatMap(
+      smj =>
+        collect(smj) {
+          case s: ShuffleExchangeExec => s
+          case s: ColumnarShuffleExchangeExec => s
+        })
+  }
+
+  private def collectAllShuffles(plan: SparkPlan): Seq[ColumnarShuffleExchangeExec] = {
+    collect(plan) { case s: ColumnarShuffleExchangeExec => s }
+  }
+
   private def collectScans(plan: SparkPlan): Seq[BatchScanExec] = {
     collect(plan) { case s: BatchScanExec => s }
   }
 
+  private def selectWithMergeJoinHint(t1: String, t2: String): String = {
+    s"SELECT /*+ MERGE($t1, $t2) */ "
+  }
+
+  private def createJoinTestDF(
+      keys: Seq[(String, String)],
+      extraColumns: Seq[String] = Nil,
+      joinType: String = ""): DataFrame = {
+    val extraColList = if (extraColumns.isEmpty) "" else extraColumns.mkString(", ", ", ", "")
+    sql(s"""
+           |${selectWithMergeJoinHint("i", "p")}
+           |id, name, i.price as purchase_price, p.price as sale_price $extraColList
+           |FROM testcat.ns.$items i $joinType JOIN testcat.ns.$purchases p
+           |ON ${keys.map(k => s"i.${k._1} = p.${k._2}").mkString(" AND ")}
+           |ORDER BY id, purchase_price, sale_price $extraColList
+           |""".stripMargin)
+  }
+
   private val customers: String = "customers"
   private val customersColumns: Array[Column] = Array(
     Column.create("customer_name", StringType),
@@ -912,13 +960,23 @@ class GlutenKeyGroupedPartitioningSuite
         s"(1, 42.0, cast('2020-01-01' as timestamp)), " +
         s"(2, 11.0, cast('2020-01-01' as timestamp))")
 
-    val df = sql(
-      "SELECT id, name, i.price as purchase_price, p.price as sale_price " +
-        s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
-        "ON i.id = p.item_id AND i.arrive_time = p.time ORDER BY id, purchase_price, sale_price")
+    Seq(true, false).foreach {
+      pushDownValues =>
+        withSQLConf(SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString) {
+          val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" -> "time"))
+          val shuffles = collectColumnarShuffleExchangeExec(df.queryExecution.executedPlan)
+          if (pushDownValues) {
+            assert(shuffles.isEmpty, "should not add shuffle when partition values mismatch")
+          } else {
+            assert(
+              shuffles.nonEmpty,
+              "should add shuffle when partition values mismatch, and " +
+                "pushing down partition values is not enabled")
+          }
 
-    val shuffles = collectColumnarShuffleExchangeExec(df.queryExecution.executedPlan)
-    assert(shuffles.nonEmpty, "should add shuffle when partition keys mismatch")
+          checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(2, "bb", 10.0, 11.0)))
+        }
+    }
   }
 
   testGluten("data source partitioning + dynamic partition filtering") {
@@ -972,4 +1030,796 @@ class GlutenKeyGroupedPartitioningSuite
       }
     }
   }
+
+  testGluten(
+    "SPARK-41471: shuffle one side: only one side reports partitioning with two identity") {
+    val items_partitions = Array(identity("id"), identity("arrive_time"))
+    createTable(items, itemsColumns, items_partitions)
+
+    sql(
+      s"INSERT INTO testcat.ns.$items VALUES " +
+        "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+        "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+        "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+    createTable(purchases, purchasesColumns, Array.empty)
+    sql(
+      s"INSERT INTO testcat.ns.$purchases VALUES " +
+        "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+        "(3, 19.5, cast('2020-02-01' as timestamp))")
+
+    Seq(true, false).foreach {
+      shuffle =>
+        withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> shuffle.toString) {
+          val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" -> "time"))
+          val shuffles = collectShuffles(df.queryExecution.executedPlan)
+          if (shuffle) {
+            assert(shuffles.size == 1, "only shuffle one side not report partitioning")
+          } else {
+            assert(
+              shuffles.size == 2,
+              "should add two side shuffle when bucketing shuffle one side" +
+                " is not enabled")
+          }
+
+          checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0)))
+        }
+    }
+  }
+
+  testGluten("SPARK-41471: shuffle one side: only one side reports partitioning") {
+    val items_partitions = Array(identity("id"))
+    createTable(items, itemsColumns, items_partitions)
+
+    sql(
+      s"INSERT INTO testcat.ns.$items VALUES " +
+        "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+        "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+        "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+    createTable(purchases, purchasesColumns, Array.empty)
+    sql(
+      s"INSERT INTO testcat.ns.$purchases VALUES " +
+        "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+        "(3, 19.5, cast('2020-02-01' as timestamp))")
+
+    Seq(true, false).foreach {
+      shuffle =>
+        withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> shuffle.toString) {
+          val df = createJoinTestDF(Seq("id" -> "item_id"))
+          val shuffles = collectShuffles(df.queryExecution.executedPlan)
+          if (shuffle) {
+            assert(shuffles.size == 1, "only shuffle one side not report partitioning")
+          } else {
+            assert(
+              shuffles.size == 2,
+              "should add two side shuffle when bucketing shuffle one side" +
+                " is not enabled")
+          }
+
+          checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5)))
+        }
+    }
+  }
+
+  testGluten("SPARK-41471: shuffle one side: shuffle side has more partition value") {
+    val items_partitions = Array(identity("id"))
+    createTable(items, itemsColumns, items_partitions)
+
+    sql(
+      s"INSERT INTO testcat.ns.$items VALUES " +
+        "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+        "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+        "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+    createTable(purchases, purchasesColumns, Array.empty)
+    sql(
+      s"INSERT INTO testcat.ns.$purchases VALUES " +
+        "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+        "(3, 19.5, cast('2020-02-01' as timestamp)), " +
+        "(5, 26.0, cast('2023-01-01' as timestamp)), " +
+        "(6, 50.0, cast('2023-02-01' as timestamp))")
+
+    Seq(true, false).foreach {
+      shuffle =>
+        withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> shuffle.toString) {
+          Seq("", "LEFT OUTER", "RIGHT OUTER", "FULL OUTER").foreach {
+            joinType =>
+              val df = createJoinTestDF(Seq("id" -> "item_id"), joinType = joinType)
+              val shuffles = collectShuffles(df.queryExecution.executedPlan)
+              if (shuffle) {
+                assert(shuffles.size == 1, "only shuffle one side not report partitioning")
+              } else {
+                assert(
+                  shuffles.size == 2,
+                  "should add two side shuffle when bucketing shuffle one " +
+                    "side is not enabled")
+              }
+              joinType match {
+                case "" =>
+                  checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5)))
+                case "LEFT OUTER" =>
+                  checkAnswer(
+                    df,
+                    Seq(
+                      Row(1, "aa", 40.0, 42.0),
+                      Row(3, "bb", 10.0, 19.5),
+                      Row(4, "cc", 15.5, null)))
+                case "RIGHT OUTER" =>
+                  checkAnswer(
+                    df,
+                    Seq(
+                      Row(null, null, null, 26.0),
+                      Row(null, null, null, 50.0),
+                      Row(1, "aa", 40.0, 42.0),
+                      Row(3, "bb", 10.0, 19.5)))
+                case "FULL OUTER" =>
+                  checkAnswer(
+                    df,
+                    Seq(
+                      Row(null, null, null, 26.0),
+                      Row(null, null, null, 50.0),
+                      Row(1, "aa", 40.0, 42.0),
+                      Row(3, "bb", 10.0, 19.5),
+                      Row(4, "cc", 15.5, null)))
+              }
+          }
+        }
+    }
+  }
+
+  testGluten("SPARK-41471: shuffle one side: partitioning with transform") {
+    val items_partitions = Array(years("arrive_time"))
+    createTable(items, itemsColumns, items_partitions)
+
+    sql(
+      s"INSERT INTO testcat.ns.$items VALUES " +
+        "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+        "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+        "(4, 'cc', 15.5, cast('2021-02-01' as timestamp))")
+
+    createTable(purchases, purchasesColumns, Array.empty)
+    sql(
+      s"INSERT INTO testcat.ns.$purchases VALUES " +
+        "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+        "(3, 19.5, cast('2021-02-01' as timestamp))")
+
+    Seq(true, false).foreach {
+      shuffle =>
+        withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> shuffle.toString) {
+          val df = createJoinTestDF(Seq("arrive_time" -> "time"))
+          val shuffles = collectShuffles(df.queryExecution.executedPlan)
+          if (shuffle) {
+            assert(shuffles.size == 1, "partitioning with transform should trigger SPJ")
+          } else {
+            assert(
+              shuffles.size == 2,
+              "should add two side shuffle when bucketing shuffle one side" +
+                " is not enabled")
+          }
+
+          checkAnswer(
+            df,
+            Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 42.0), Row(4, "cc", 15.5, 19.5)))
+        }
+    }
+  }
+
+  testGluten(
+    "SPARK-44647: SPJ: test join key is subset of cluster key " +
+      "with push values and partially-clustered") {
+    val table1 = "tab1e1"
+    val table2 = "table2"
+    val partition = Array(identity("id"), identity("data"))
+    createTable(table1, columns, partition)
+    sql(
+      s"INSERT INTO testcat.ns.$table1 VALUES " +
+        "(1, 'aa', cast('2020-01-01' as timestamp)), " +
+        "(2, 'bb', cast('2020-01-01' as timestamp)), " +
+        "(2, 'cc', cast('2020-01-01' as timestamp)), " +
+        "(3, 'dd', cast('2020-01-01' as timestamp)), " +
+        "(3, 'dd', cast('2020-01-01' as timestamp)), " +
+        "(3, 'ee', cast('2020-01-01' as timestamp)), " +
+        "(3, 'ee', cast('2020-01-01' as timestamp))")
+
+    createTable(table2, columns, partition)
+    sql(
+      s"INSERT INTO testcat.ns.$table2 VALUES " +
+        "(4, 'zz', cast('2020-01-01' as timestamp)), " +
+        "(4, 'zz', cast('2020-01-01' as timestamp)), " +
+        "(3, 'yy', cast('2020-01-01' as timestamp)), " +
+        "(3, 'yy', cast('2020-01-01' as timestamp)), " +
+        "(3, 'xx', cast('2020-01-01' as timestamp)), " +
+        "(3, 'xx', cast('2020-01-01' as timestamp)), " +
+        "(2, 'ww', cast('2020-01-01' as timestamp))")
+
+    Seq(true, false).foreach {
+      pushDownValues =>
+        Seq(true, false).foreach {
+          filter =>
+            Seq(true, false).foreach {
+              partiallyClustered =>
+                Seq(true, false).foreach {
+                  allowJoinKeysSubsetOfPartitionKeys =>
+                    withSQLConf(
+                      SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+                      SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString,
+                      SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key ->
+                        partiallyClustered.toString,
+                      SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED.key -> filter.toString,
+                      SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key ->
+                        allowJoinKeysSubsetOfPartitionKeys.toString
+                    ) {
+                      val df = sql(s"""
+                                      |${selectWithMergeJoinHint("t1", "t2")}
+                                      |t1.id AS id, t1.data AS t1data, t2.data AS t2data
+                                      |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2
+                                      |ON t1.id = t2.id ORDER BY t1.id, t1data, t2data
+                                      |""".stripMargin)
+                      val shuffles =
+                        collectColumnarShuffleExchangeExec(df.queryExecution.executedPlan)
+                      if (allowJoinKeysSubsetOfPartitionKeys) {
+                        assert(shuffles.isEmpty, "SPJ should be triggered")
+                      } else {
+                        assert(shuffles.nonEmpty, "SPJ should not be triggered")
+                      }
+
+                      val scannedPartitions = collectScans(df.queryExecution.executedPlan)
+                        .map(_.inputRDD.partitions.length)
+                      (allowJoinKeysSubsetOfPartitionKeys, partiallyClustered, filter) match {
+                        // SPJ, partially-clustered, with filter
+                        case (true, true, true) => assert(scannedPartitions == Seq(6, 6))
+
+                        // SPJ, partially-clustered, no filter
+                        case (true, true, false) => assert(scannedPartitions == Seq(8, 8))
+
+                        // SPJ and not partially-clustered, with filter
+                        case (true, false, true) => assert(scannedPartitions == Seq(2, 2))
+
+                        // SPJ and not partially-clustered, no filter
+                        case (true, false, false) => assert(scannedPartitions == Seq(4, 4))
+
+                        // No SPJ
+                        case _ => assert(scannedPartitions == Seq(5, 4))
+                      }
+
+                      checkAnswer(
+                        df,
+                        Seq(
+                          Row(2, "bb", "ww"),
+                          Row(2, "cc", "ww"),
+                          Row(3, "dd", "xx"),
+                          Row(3, "dd", "xx"),
+                          Row(3, "dd", "xx"),
+                          Row(3, "dd", "xx"),
+                          Row(3, "dd", "yy"),
+                          Row(3, "dd", "yy"),
+                          Row(3, "dd", "yy"),
+                          Row(3, "dd", "yy"),
+                          Row(3, "ee", "xx"),
+                          Row(3, "ee", "xx"),
+                          Row(3, "ee", "xx"),
+                          Row(3, "ee", "xx"),
+                          Row(3, "ee", "yy"),
+                          Row(3, "ee", "yy"),
+                          Row(3, "ee", "yy"),
+                          Row(3, "ee", "yy")
+                        )
+                      )
+                    }
+                }
+            }
+        }
+    }
+  }
+
+  testGluten("SPARK-44647: test join key is the second cluster key") {
+    val table1 = "tab1e1"
+    val table2 = "table2"
+    val partition = Array(identity("id"), identity("data"))
+    createTable(table1, columns, partition)
+    sql(
+      s"INSERT INTO testcat.ns.$table1 VALUES " +
+        "(1, 'aa', cast('2020-01-01' as timestamp)), " +
+        "(2, 'bb', cast('2020-01-02' as timestamp)), " +
+        "(3, 'cc', cast('2020-01-03' as timestamp))")
+
+    createTable(table2, columns, partition)
+    sql(
+      s"INSERT INTO testcat.ns.$table2 VALUES " +
+        "(4, 'aa', cast('2020-01-01' as timestamp)), " +
+        "(5, 'bb', cast('2020-01-02' as timestamp)), " +
+        "(6, 'cc', cast('2020-01-03' as timestamp))")
+
+    Seq(true, false).foreach {
+      pushDownValues =>
+        Seq(true, false).foreach {
+          partiallyClustered =>
+            Seq(true, false).foreach {
+              allowJoinKeysSubsetOfPartitionKeys =>
+                withSQLConf(
+                  SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+                  SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key ->
+                    pushDownValues.toString,
+                  SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key ->
+                    partiallyClustered.toString,
+                  SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key ->
+                    allowJoinKeysSubsetOfPartitionKeys.toString
+                ) {
+
+                  val df = sql(s"""
+                                  |${selectWithMergeJoinHint("t1", "t2")}
+                                  |t1.id AS t1id, t2.id as t2id, t1.data AS data
+                                  |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2
+                                  |ON t1.data = t2.data
+                                  |ORDER BY t1id, t1id, data
+                                  |""".stripMargin)
+                  checkAnswer(df, Seq(Row(1, 4, "aa"), Row(2, 5, "bb"), Row(3, 6, "cc")))
+
+                  val shuffles = collectColumnarShuffleExchangeExec(df.queryExecution.executedPlan)
+                  if (allowJoinKeysSubsetOfPartitionKeys) {
+                    assert(shuffles.isEmpty, "SPJ should be triggered")
+                  } else {
+                    assert(shuffles.nonEmpty, "SPJ should not be triggered")
+                  }
+
+                  val scans = collectScans(df.queryExecution.executedPlan)
+                    .map(_.inputRDD.partitions.length)
+                  (pushDownValues, allowJoinKeysSubsetOfPartitionKeys, partiallyClustered) match {
+                    // SPJ and partially-clustered
+                    case (true, true, true) => assert(scans == Seq(3, 3))
+                    // non-SPJ or SPJ/partially-clustered
+                    case _ => assert(scans == Seq(3, 3))
+                  }
+                }
+            }
+        }
+    }
+  }
+
+  testGluten("SPARK-44647: test join key is the second partition key and a transform") {
+    val items_partitions = Array(bucket(8, "id"), days("arrive_time"))
+    createTable(items, itemsColumns, items_partitions)
+    sql(
+      s"INSERT INTO testcat.ns.$items VALUES " +
+        s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+        s"(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " +
+        s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+        s"(2, 'bb', 10.5, cast('2020-01-01' as timestamp)), " +
+        s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+    val purchases_partitions = Array(bucket(8, "item_id"), days("time"))
+    createTable(purchases, purchasesColumns, purchases_partitions)
+    sql(
+      s"INSERT INTO testcat.ns.$purchases VALUES " +
+        s"(1, 42.0, cast('2020-01-01' as timestamp)), " +
+        s"(1, 44.0, cast('2020-01-15' as timestamp)), " +
+        s"(1, 45.0, cast('2020-01-15' as timestamp)), " +
+        s"(2, 11.0, cast('2020-01-01' as timestamp)), " +
+        s"(3, 19.5, cast('2020-02-01' as timestamp))")
+
+    Seq(true, false).foreach {
+      pushDownValues =>
+        Seq(true, false).foreach {
+          partiallyClustered =>
+            Seq(true, false).foreach {
+              allowJoinKeysSubsetOfPartitionKeys =>
+                withSQLConf(
+                  SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+                  SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString,
+                  SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key ->
+                    partiallyClustered.toString,
+                  SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key ->
+                    allowJoinKeysSubsetOfPartitionKeys.toString
+                ) {
+                  val df =
+                    createJoinTestDF(Seq("arrive_time" -> "time"), extraColumns = Seq("p.item_id"))
+                  // Currently SPJ for case where join key not same as partition key
+                  // only supported when push-part-values enabled
+                  val shuffles = collectColumnarShuffleExchangeExec(df.queryExecution.executedPlan)
+                  if (allowJoinKeysSubsetOfPartitionKeys) {
+                    assert(shuffles.isEmpty, "SPJ should be triggered")
+                  } else {
+                    assert(shuffles.nonEmpty, "SPJ should not be triggered")
+                  }
+
+                  val scans = collectScans(df.queryExecution.executedPlan)
+                    .map(_.inputRDD.partitions.length)
+                  (allowJoinKeysSubsetOfPartitionKeys, partiallyClustered) match {
+                    // SPJ and partially-clustered
+                    case (true, true) => assert(scans == Seq(5, 5))
+                    // SPJ and not partially-clustered
+                    case (true, false) => assert(scans == Seq(3, 3))
+                    // No SPJ
+                    case _ => assert(scans == Seq(4, 4))
+                  }
+
+                  checkAnswer(
+                    df,
+                    Seq(
+                      Row(1, "aa", 40.0, 11.0, 2),
+                      Row(1, "aa", 40.0, 42.0, 1),
+                      Row(1, "aa", 41.0, 44.0, 1),
+                      Row(1, "aa", 41.0, 45.0, 1),
+                      Row(2, "bb", 10.0, 11.0, 2),
+                      Row(2, "bb", 10.0, 42.0, 1),
+                      Row(2, "bb", 10.5, 11.0, 2),
+                      Row(2, "bb", 10.5, 42.0, 1),
+                      Row(3, "cc", 15.5, 19.5, 3)
+                    )
+                  )
+                }
+            }
+        }
+    }
+  }
+
+  testGluten("SPARK-44647: shuffle one side and join keys are less than partition keys") {
+    val items_partitions = Array(identity("id"), identity("name"))
+    createTable(items, itemsColumns, items_partitions)
+
+    sql(
+      s"INSERT INTO testcat.ns.$items VALUES " +
+        "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+        "(1, 'aa', 30.0, cast('2020-01-02' as timestamp)), " +
+        "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+        "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+    createTable(purchases, purchasesColumns, Array.empty)
+    sql(
+      s"INSERT INTO testcat.ns.$purchases VALUES " +
+        "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+        "(1, 89.0, cast('2020-01-03' as timestamp)), " +
+        "(3, 19.5, cast('2020-02-01' as timestamp)), " +
+        "(5, 26.0, cast('2023-01-01' as timestamp)), " +
+        "(6, 50.0, cast('2023-02-01' as timestamp))")
+
+    Seq(true, false).foreach {
+      pushdownValues =>
+        withSQLConf(
+          SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
+          SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString,
+          SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false",
+          SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true"
+        ) {
+          val df = createJoinTestDF(Seq("id" -> "item_id"))
+          val shuffles = collectShuffles(df.queryExecution.executedPlan)
+          assert(shuffles.size == 1, "SPJ should be triggered")
+          checkAnswer(
+            df,
+            Seq(
+              Row(1, "aa", 30.0, 42.0),
+              Row(1, "aa", 30.0, 89.0),
+              Row(1, "aa", 40.0, 42.0),
+              Row(1, "aa", 40.0, 89.0),
+              Row(3, "bb", 10.0, 19.5)))
+        }
+    }
+  }
+
+  // Gluten-side test for SPARK-47094. Both tables bucket dept_id into 2 buckets, but store_id
+  // is bucketed into 4 buckets on table1 and 2 on table2, with compatible transforms enabled.
+  // Storage-partitioned join is expected only when partition-value push-down is on and
+  // partial clustering is off; every other combination must plan a shuffle.
+  testGluten(
+    "SPARK-47094: Compatible buckets does not support SPJ with " +
+      "push-down values or partially-clustered") {
+    // "tab1e1" (digit 1) looks like a typo, but it is only used as a table identifier here.
+    val table1 = "tab1e1"
+    val table2 = "table2"
+
+    val partition1 = Array(bucket(4, "store_id"), bucket(2, "dept_id"))
+    val partition2 = Array(bucket(2, "store_id"), bucket(2, "dept_id"))
+
+    createTable(table1, columns2, partition1)
+    sql(
+      s"INSERT INTO testcat.ns.$table1 VALUES " +
+        "(0, 0, 'aa'), " +
+        "(1, 1, 'bb'), " +
+        "(2, 2, 'cc')"
+    )
+
+    // Same three rows on both sides, so the join result is identical for every config below.
+    createTable(table2, columns2, partition2)
+    sql(
+      s"INSERT INTO testcat.ns.$table2 VALUES " +
+        "(0, 0, 'aa'), " +
+        "(1, 1, 'bb'), " +
+        "(2, 2, 'cc')"
+    )
+
+    // Exercise all four combinations of push-down and partial clustering.
+    Seq(true, false).foreach {
+      allowPushDown =>
+        Seq(true, false).foreach {
+          partiallyClustered =>
+            withSQLConf(
+              SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+              SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> 
allowPushDown.toString,
+              
SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key ->
+                partiallyClustered.toString,
+              
SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true",
+              SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true"
+            ) {
+              val df = sql(s"""
+                              |${selectWithMergeJoinHint("t1", "t2")}
+                              |t1.store_id, t1.store_id, t1.dept_id, 
t2.dept_id, t1.data, t2.data
+                              |FROM testcat.ns.$table1 t1 JOIN 
testcat.ns.$table2 t2
+                              |ON t1.store_id = t2.store_id AND t1.dept_id = 
t2.dept_id
+                              |ORDER BY t1.store_id, t1.dept_id, t1.data, 
t2.data
+                              |""".stripMargin)
+
+              val shuffles = 
collectColumnarShuffleExchangeExec(df.queryExecution.executedPlan)
+              // Number of input partitions reported by each of the two scans.
+              val scans =
+                
collectScans(df.queryExecution.executedPlan).map(_.inputRDD.partitions.length)
+
+              // Only (push-down = true, partially-clustered = false) avoids the shuffle; in
+              // that case both scans are expected to report 2 partitions, otherwise (3, 2).
+              (allowPushDown, partiallyClustered) match {
+                case (true, false) =>
+                  assert(shuffles.isEmpty, "SPJ should be triggered")
+                  assert(scans == Seq(2, 2))
+                case (_, _) =>
+                  assert(shuffles.nonEmpty, "SPJ should not be triggered")
+                  assert(scans == Seq(3, 2))
+              }
+
+              checkAnswer(
+                df,
+                Seq(
+                  Row(0, 0, 0, 0, "aa", "aa"),
+                  Row(1, 1, 1, 1, "bb", "bb"),
+                  Row(2, 2, 2, 2, "cc", "cc")
+                ))
+            }
+        }
+    }
+  }
+
+  // Gluten-side test for SPARK-47094. Both tables are bucketed on store_id only, using the
+  // bucket-count pairs (2, 3) and (3, 4), which the test title calls incompatible. The join
+  // is on store_id and dept_id, and a shuffle is expected in every case, whether or not join
+  // keys may be a subset of the partition keys.
+  testGluten(
+    "SPARK-47094: SPJ: Does not trigger when incompatible number of buckets on 
both side") {
+    val table1 = "tab1e1"
+    val table2 = "table2"
+
+    Seq(
+      (2, 3),
+      (3, 4)
+    ).foreach {
+      case (table1buckets1, table2buckets1) =>
+        // Reset the test catalog so each bucket-count pair recreates both tables from scratch.
+        catalog.clearTables()
+
+        val partition1 = Array(bucket(table1buckets1, "store_id"))
+        val partition2 = Array(bucket(table2buckets1, "store_id"))
+
+        // Both tables receive the same five rows.
+        Seq((table1, partition1), (table2, partition2)).foreach {
+          case (tab, part) =>
+            createTable(tab, columns2, part)
+            val insertStr = s"INSERT INTO testcat.ns.$tab VALUES " +
+              "(0, 0, 'aa'), " +
+              "(1, 0, 'ab'), " + // duplicate partition key
+              "(2, 2, 'ac'), " +
+              "(3, 3, 'ad'), " +
+              "(4, 2, 'bc') "
+
+            sql(insertStr)
+        }
+
+        Seq(true, false).foreach {
+          allowJoinKeysSubsetOfPartitionKeys =>
+            withSQLConf(
+              SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+              SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+              
SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false",
+              
SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key ->
+                allowJoinKeysSubsetOfPartitionKeys.toString,
+              SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true"
+            ) {
+              val df = sql(s"""
+                              |${selectWithMergeJoinHint("t1", "t2")}
+                              |t1.store_id, t1.dept_id, t1.data, t2.data
+                              |FROM testcat.ns.$table1 t1 JOIN 
testcat.ns.$table2 t2
+                              |ON t1.store_id = t2.store_id AND t1.dept_id = 
t2.dept_id
+                              |""".stripMargin)
+
+              // Only the plan shape is checked here; the query result is not asserted.
+              val shuffles = 
collectColumnarShuffleExchangeExec(df.queryExecution.executedPlan)
+              assert(shuffles.nonEmpty, "SPJ should not be triggered")
+            }
+        }
+    }
+  }
+
+  // Gluten-side test for SPARK-48655. The table is partitioned by identity(price) and
+  // identity(id), and the data contains a null in each of those columns. Ordering by the
+  // partition keys must plan no shuffle when V2 bucketing sorting is enabled, and exactly
+  // one shuffle when it is disabled; results are checked in both modes.
+  testGluten("SPARK-48655: order by on partition keys should not introduce 
additional shuffle") {
+    val items_partitions = Array(identity("price"), identity("id"))
+    createTable(items, itemsColumns, items_partitions)
+    sql(
+      s"INSERT INTO testcat.ns.$items VALUES " +
+        s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+        s"(1, 'aa', 41.0, cast('2020-01-02' as timestamp)), " +
+        s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+        s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp)), " +
+        s"(null, 'cc', 15.5, cast('2020-02-01' as timestamp)), " +
+        s"(3, 'cc', null, cast('2020-02-01' as timestamp))")
+
+    Seq(true, false).foreach {
+      sortingEnabled =>
+        withSQLConf(SQLConf.V2_BUCKETING_SORTING_ENABLED.key -> 
sortingEnabled.toString) {
+
+          // Runs `cmd`, asserts the shuffle count expected for the current `sortingEnabled`
+          // value, then checks the rows. NOTE(review): the trailing `: Unit` ascription after
+          // the body is redundant.
+          def verifyShuffle(cmd: String, answer: Seq[Row]): Unit = {
+            val df = sql(cmd)
+            if (sortingEnabled) {
+              assert(
+                collectAllShuffles(df.queryExecution.executedPlan).isEmpty,
+                "should contain no shuffle when sorting by partition values")
+            } else {
+              assert(
+                collectAllShuffles(df.queryExecution.executedPlan).size == 1,
+                "should contain one shuffle when optimization is disabled")
+            }
+            checkAnswer(df, answer)
+          }: Unit
+
+          // The expected rows follow Spark's defaults: ASC puts nulls first, DESC puts them
+          // last, unless NULLS FIRST / NULLS LAST is given explicitly.
+          verifyShuffle(
+            s"SELECT price, id FROM testcat.ns.$items ORDER BY price ASC, id 
ASC",
+            Seq(
+              Row(null, 3),
+              Row(10.0, 2),
+              Row(15.5, null),
+              Row(15.5, 3),
+              Row(40.0, 1),
+              Row(41.0, 1)))
+
+          verifyShuffle(
+            s"SELECT price, id FROM testcat.ns.$items " +
+              s"ORDER BY price ASC NULLS LAST, id ASC NULLS LAST",
+            Seq(
+              Row(10.0, 2),
+              Row(15.5, 3),
+              Row(15.5, null),
+              Row(40.0, 1),
+              Row(41.0, 1),
+              Row(null, 3))
+          )
+
+          verifyShuffle(
+            s"SELECT price, id FROM testcat.ns.$items ORDER BY price DESC, id 
ASC",
+            Seq(
+              Row(41.0, 1),
+              Row(40.0, 1),
+              Row(15.5, null),
+              Row(15.5, 3),
+              Row(10.0, 2),
+              Row(null, 3))
+          )
+
+          verifyShuffle(
+            s"SELECT price, id FROM testcat.ns.$items ORDER BY price DESC, id 
DESC",
+            Seq(
+              Row(41.0, 1),
+              Row(40.0, 1),
+              Row(15.5, 3),
+              Row(15.5, null),
+              Row(10.0, 2),
+              Row(null, 3))
+          )
+
+          // NOTE(review): the `;` terminating this last call is redundant in Scala.
+          verifyShuffle(
+            s"SELECT price, id FROM testcat.ns.$items " +
+              s"ORDER BY price DESC NULLS FIRST, id DESC NULLS FIRST",
+            Seq(
+              Row(null, 3),
+              Row(41.0, 1),
+              Row(40.0, 1),
+              Row(15.5, null),
+              Row(15.5, 3),
+              Row(10.0, 2))
+          );
+        }
+    }
+  }
+
+  // Gluten-side test for SPARK-48012. `items` is partitioned by bucket(2, id) and
+  // identity(arrive_time), in both orders, while `purchases` is created with no partitioning.
+  // With V2 bucketing shuffle enabled, joining on both keys must plan exactly one shuffle,
+  // which the assertion message attributes to the side that reports no partitioning.
+  testGluten("SPARK-48012: one-side shuffle with partition transforms") {
+    val items_partitions = Array(bucket(2, "id"), identity("arrive_time"))
+    val items_partitions2 = Array(identity("arrive_time"), bucket(2, "id"))
+
+    Seq(items_partitions, items_partitions2).foreach {
+      partition =>
+        // Recreate both tables for each partition-transform ordering.
+        catalog.clearTables()
+
+        createTable(items, itemsColumns, partition)
+        sql(
+          s"INSERT INTO testcat.ns.$items VALUES " +
+            "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+            "(1, 'bb', 30.0, cast('2020-01-01' as timestamp)), " +
+            "(1, 'cc', 30.0, cast('2020-01-02' as timestamp)), " +
+            "(3, 'dd', 10.0, cast('2020-01-01' as timestamp)), " +
+            "(4, 'ee', 15.5, cast('2020-02-01' as timestamp)), " +
+            "(5, 'ff', 32.1, cast('2020-03-01' as timestamp))")
+
+        createTable(purchases, purchasesColumns, Array.empty)
+        sql(
+          s"INSERT INTO testcat.ns.$purchases VALUES " +
+            "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+            "(2, 10.7, cast('2020-01-01' as timestamp))," +
+            "(3, 19.5, cast('2020-02-01' as timestamp))," +
+            "(4, 56.5, cast('2020-02-01' as timestamp))")
+
+        withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true") {
+          val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" -> 
"time"))
+          val shuffles = collectShuffles(df.queryExecution.executedPlan)
+          assert(shuffles.size == 1, "only shuffle side that does not report 
partitioning")
+
+          checkAnswer(
+            df,
+            Seq(Row(1, "bb", 30.0, 42.0), Row(1, "aa", 40.0, 42.0), Row(4, 
"ee", 15.5, 56.5)))
+        }
+    }
+  }
+
+  // Gluten-side test for SPARK-48012 with partition-value push-down toggled on and off.
+  // `items` is partitioned by bucket(2, id) and identity(arrive_time), and `purchases` has no
+  // partitioning. Exactly one shuffle and the same two joined rows are expected either way.
+  testGluten("SPARK-48012: one-side shuffle with partition transforms and 
pushdown values") {
+    val items_partitions = Array(bucket(2, "id"), identity("arrive_time"))
+    createTable(items, itemsColumns, items_partitions)
+
+    sql(
+      s"INSERT INTO testcat.ns.$items VALUES " +
+        "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+        "(1, 'bb', 30.0, cast('2020-01-01' as timestamp)), " +
+        "(1, 'cc', 30.0, cast('2020-01-02' as timestamp))")
+
+    createTable(purchases, purchasesColumns, Array.empty)
+    sql(
+      s"INSERT INTO testcat.ns.$purchases VALUES " +
+        "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+        "(2, 10.7, cast('2020-01-01' as timestamp))")
+
+    // NOTE(review): the extra `{ ... }` block wrapping withSQLConf below is redundant.
+    Seq(true, false).foreach {
+      pushDown =>
+        {
+          withSQLConf(
+            SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
+            SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key ->
+              pushDown.toString) {
+            val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" -> 
"time"))
+            val shuffles = collectShuffles(df.queryExecution.executedPlan)
+            assert(shuffles.size == 1, "only shuffle side that does not report 
partitioning")
+
+            checkAnswer(df, Seq(Row(1, "bb", 30.0, 42.0), Row(1, "aa", 40.0, 
42.0)))
+          }
+        }
+    }
+  }
+
+  // Gluten-side test for SPARK-48012 where the join uses fewer keys than the table is
+  // partitioned by. `items` is partitioned by bucket(2, id) and identity(name), but
+  // createJoinTestDF only pairs id with item_id. That is permitted via
+  // V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS, and exactly one shuffle is expected.
+  // NOTE(review): "kes" in the test name looks like a typo for "keys". The name presumably
+  // mirrors the upstream Spark test, so confirm before renaming it.
+  testGluten(
+    "SPARK-48012: one-side shuffle with partition transforms " +
+      "with fewer join keys than partition kes") {
+    val items_partitions = Array(bucket(2, "id"), identity("name"))
+    createTable(items, itemsColumns, items_partitions)
+
+    sql(
+      s"INSERT INTO testcat.ns.$items VALUES " +
+        "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+        "(1, 'aa', 30.0, cast('2020-01-02' as timestamp)), " +
+        "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+        "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+    createTable(purchases, purchasesColumns, Array.empty)
+    sql(
+      s"INSERT INTO testcat.ns.$purchases VALUES " +
+        "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+        "(1, 89.0, cast('2020-01-03' as timestamp)), " +
+        "(3, 19.5, cast('2020-02-01' as timestamp)), " +
+        "(5, 26.0, cast('2023-01-01' as timestamp)), " +
+        "(6, 50.0, cast('2023-02-01' as timestamp))")
+
+    withSQLConf(
+      SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+      SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
+      SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+      SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> 
"false",
+      SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> 
"true"
+    ) {
+      val df = createJoinTestDF(Seq("id" -> "item_id"))
+      val shuffles = collectShuffles(df.queryExecution.executedPlan)
+      assert(shuffles.size == 1, "SPJ should be triggered")
+      checkAnswer(
+        df,
+        Seq(
+          Row(1, "aa", 30.0, 42.0),
+          Row(1, "aa", 30.0, 89.0),
+          Row(1, "aa", 40.0, 42.0),
+          Row(1, "aa", 40.0, 89.0),
+          Row(3, "bb", 10.0, 19.5)))
+    }
+  }
+
 }
diff --git 
a/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala
 
b/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala
index 9077fe5abc..247394ba91 100644
--- 
a/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala
+++ 
b/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, 
Distribution, KeyGroupedPartitioning, Partitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, 
Distribution, KeyGroupedPartitioning, KeyGroupedShuffleSpec, Partitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, 
InternalRowComparableWrapper, TimestampFormatter}
@@ -47,7 +47,7 @@ import org.apache.spark.sql.connector.read.{HasPartitionKey, 
InputPartition, Sca
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, 
ParquetFilters}
-import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, 
DataSourceV2ScanExecBase}
+import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, 
BatchScanExecShim, DataSourceV2ScanExecBase}
 import org.apache.spark.sql.execution.datasources.v2.text.TextScan
 import org.apache.spark.sql.execution.datasources.v2.utils.CatalogUtil
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, 
ShuffleExchangeLike}
@@ -482,10 +482,9 @@ class Spark40Shims extends SparkShims {
       applyPartialClustering: Boolean,
       replicatePartitions: Boolean,
       joinKeyPositions: Option[Seq[Int]] = None): Seq[Seq[InputPartition]] = {
+    val original = batchScan.asInstanceOf[BatchScanExecShim]
     scan match {
       case _ if keyGroupedPartitioning.isDefined =>
-        var finalPartitions = filteredPartitions
-
         outputPartitioning match {
           case p: KeyGroupedPartitioning =>
             assert(keyGroupedPartitioning.isDefined)
@@ -516,8 +515,20 @@ class Spark40Shims extends SparkShims {
             }
 
             // Also re-group the partitions if we are reducing compatible 
partition expressions
-            // TODO: Respect Reducer settings?
-            val finalGroupedPartitions = groupedPartitions
+            val finalGroupedPartitions = original.reducers match {
+              case Some(reducers) =>
+                val result = groupedPartitions
+                  .groupBy {
+                    case (row, _) =>
+                      KeyGroupedShuffleSpec.reducePartitionValue(row, 
partExpressions, reducers)
+                  }
+                  .map { case (wrapper, splits) => (wrapper.row, 
splits.flatMap(_._2)) }
+                  .toSeq
+                val rowOrdering =
+                  
RowOrdering.createNaturalAscendingOrdering(partExpressions.map(_.dataType))
+                result.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1))
+              case _ => groupedPartitions
+            }
 
             // When partially clustered, the input partitions are not grouped 
by partition
             // values. Here we'll need to check `commonPartitionValues` and 
decide how to group
@@ -587,9 +598,8 @@ class Spark40Shims extends SparkShims {
               }
             }
 
-          case _ =>
+          case _ => filteredPartitions
         }
-        finalPartitions
       case _ =>
         filteredPartitions
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to