This is an automated email from the ASF dual-hosted git repository.
chengchengjin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new d636fa77c4 [GLUTEN-11088][VL] Fix the Spark4.0 storage partition join (#11184)
d636fa77c4 is described below
commit d636fa77c49e991eb02159a0c25431eb499c6da2
Author: Jin Chengcheng <[email protected]>
AuthorDate: Thu Nov 27 11:10:43 2025 +0000
[GLUTEN-11088][VL] Fix the Spark4.0 storage partition join (#11184)
---
.../ColumnarShuffleExchangeExecBase.scala | 23 +-
.../gluten/utils/velox/VeloxTestSettings.scala | 24 +-
.../GlutenKeyGroupedPartitioningSuite.scala | 866 ++++++++++++++++++++-
.../gluten/sql/shims/spark40/Spark40Shims.scala | 26 +-
4 files changed, 905 insertions(+), 34 deletions(-)
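
Background: storage-partitioned join (SPJ) lets Spark 4.0 plan a join without a
shuffle when both sides report compatible KeyGroupedPartitioning from the data
source. This commit makes Gluten's columnar shuffle validation explicit about
the partitionings it supports, re-enables the key-grouped partitioning suite
for Spark 4.0, and ports the reducer-based partition regrouping into the
Spark 4.0 shim. As a rough sketch (assumes a running SparkSession named
`spark`; the tests themselves use withSQLConf instead), the SQLConf switches
exercised below can also be toggled on a session directly:

    import org.apache.spark.sql.internal.SQLConf

    // Enable v2 bucketing and partition-value push-down, the two switches
    // most of the new tests flip.
    spark.conf.set(SQLConf.V2_BUCKETING_ENABLED.key, "true")
    spark.conf.set(SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key, "true")
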
diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExecBase.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExecBase.scala
index 44f37e4ffb..17d1ec4038 100644
--- a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExecBase.scala
+++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExecBase.scala
@@ -27,7 +27,7 @@ import org.apache.spark.serializer.Serializer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.Statistics
-import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.plans.physical.{SinglePartition, _}
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution.exchange._
import org.apache.spark.sql.execution.metric.SQLShuffleWriteMetricsReporter
@@ -93,14 +93,21 @@ abstract class ColumnarShuffleExchangeExecBase(
var cachedShuffleRDD: ShuffledColumnarBatchRDD = _
override protected def doValidateInternal(): ValidationResult = {
- BackendsApiManager.getValidatorApiInstance
+ val validation = BackendsApiManager.getValidatorApiInstance
.doColumnarShuffleExchangeExecValidate(output, outputPartitioning, child)
- .map {
- reason =>
- ValidationResult.failed(
- s"Found schema check failure for schema ${child.schema} due to:
$reason")
- }
- .getOrElse(ValidationResult.succeeded)
+ if (validation.nonEmpty) {
+ return ValidationResult.failed(
+ s"Found schema check failure for schema ${child.schema} due to:
${validation.get}")
+ }
+ outputPartitioning match {
+ case _: HashPartitioning => ValidationResult.succeeded
+ case _: RangePartitioning => ValidationResult.succeeded
+ case SinglePartition => ValidationResult.succeeded
+ case _: RoundRobinPartitioning => ValidationResult.succeeded
+ case _ =>
+ ValidationResult.failed(
+ s"Unsupported partitioning
${outputPartitioning.getClass.getSimpleName}")
+ }
}
override def numMappers: Int = inputColumnarRDD.getNumPartitions
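
Note on the hunk above: the rewritten doValidateInternal first surfaces any
backend schema-check failure, then accepts only partitionings the columnar
exchange implements; anything else (e.g. KeyGroupedPartitioning) now fails
validation and falls back to vanilla Spark. A minimal standalone sketch of that
whitelist (hypothetical `isSupported` helper, not part of the commit):

    import org.apache.spark.sql.catalyst.plans.physical._

    // Mirrors the match in doValidateInternal: only these four schemes
    // are executable by the columnar shuffle.
    def isSupported(p: Partitioning): Boolean = p match {
      case _: HashPartitioning | _: RangePartitioning => true
      case SinglePartition | _: RoundRobinPartitioning => true
      case _ => false
    }
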
diff --git a/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
index cd505ffa21..07437631f9 100644
--- a/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
+++ b/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
@@ -62,16 +62,20 @@ class VeloxTestSettings extends BackendTestSettings {
enableSuite[GlutenFileDataSourceV2FallBackSuite]
// Rewritten
.exclude("Fallback Parquet V2 to V1")
- // TODO: fix in Spark-4.0
- // enableSuite[GlutenKeyGroupedPartitioningSuite]
- // // NEW SUITE: disable as they check vanilla spark plan
- // .exclude("partitioned join: number of buckets mismatch should trigger
shuffle")
- // .exclude("partitioned join: only one side reports partitioning")
- // .exclude("partitioned join: join with two partition keys and different
# of partition keys")
- // // disable due to check for SMJ node
- // .excludeByPrefix("SPARK-41413: partitioned join:")
- // .excludeByPrefix("SPARK-42038: partially clustered:")
- // .exclude("SPARK-44641: duplicated records when SPJ is not triggered")
+ enableSuite[GlutenKeyGroupedPartitioningSuite]
+ // NEW SUITE: disable as they check vanilla spark plan
+ .exclude("partitioned join: number of buckets mismatch should trigger
shuffle")
+ .exclude("partitioned join: only one side reports partitioning")
+ .exclude("partitioned join: join with two partition keys and different #
of partition keys")
+ .excludeByPrefix("SPARK-47094")
+ .excludeByPrefix("SPARK-48655")
+ .excludeByPrefix("SPARK-48012")
+ .excludeByPrefix("SPARK-44647")
+ .excludeByPrefix("SPARK-41471")
+ // disable due to check for SMJ node
+ .excludeByPrefix("SPARK-41413: partitioned join:")
+ .excludeByPrefix("SPARK-42038: partially clustered:")
+ .exclude("SPARK-44641: duplicated records when SPJ is not triggered")
enableSuite[GlutenLocalScanSuite]
enableSuite[GlutenMetadataColumnSuite]
enableSuite[GlutenSupportsCatalogOptionsSuite]
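
For context, BackendTestSettings controls which vanilla Spark suites run
against the Velox backend: enableSuite registers a suite, and
exclude/excludeByPrefix opt individual cases out by exact name or ticket
prefix, which is the idiom the hunk above uses to re-enable
GlutenKeyGroupedPartitioningSuite. A sketch of the same idiom (suite and test
names hypothetical):

    enableSuite[GlutenSomeSuite]
      // skip tests that assert on vanilla Spark plan shapes
      .exclude("exact test name")
      .excludeByPrefix("SPARK-12345")
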
diff --git a/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/connector/GlutenKeyGroupedPartitioningSuite.scala b/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/connector/GlutenKeyGroupedPartitioningSuite.scala
index ef87b50400..00d370b39f 100644
--- a/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/connector/GlutenKeyGroupedPartitioningSuite.scala
+++ b/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/connector/GlutenKeyGroupedPartitioningSuite.scala
@@ -20,13 +20,14 @@ import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.execution.SortMergeJoinExecTransformer
import org.apache.spark.SparkConf
-import org.apache.spark.sql.{GlutenSQLTestsBaseTrait, Row}
+import org.apache.spark.sql.{DataFrame, GlutenSQLTestsBaseTrait, Row}
import org.apache.spark.sql.connector.catalog.{Column, Identifier, InMemoryTableCatalog}
import org.apache.spark.sql.connector.distributions.Distributions
-import org.apache.spark.sql.connector.expressions.Expressions.{bucket, days, identity}
+import org.apache.spark.sql.connector.expressions.Expressions.{bucket, days, identity, years}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.execution.{ColumnarShuffleExchangeExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
+import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -47,6 +48,17 @@ class GlutenKeyGroupedPartitioningSuite
private val emptyProps: java.util.Map[String, String] = {
Collections.emptyMap[String, String]
}
+
+ private val columns: Array[Column] = Array(
+ Column.create("id", IntegerType),
+ Column.create("data", StringType),
+ Column.create("ts", TimestampType))
+
+ private val columns2: Array[Column] = Array(
+ Column.create("store_id", IntegerType),
+ Column.create("dept_id", IntegerType),
+ Column.create("data", StringType))
+
private def createTable(
table: String,
columns: Array[Column],
@@ -72,10 +84,46 @@ class GlutenKeyGroupedPartitioningSuite
case s: SortMergeJoinExec => s
}.flatMap(smj => collect(smj) { case s: ColumnarShuffleExchangeExec => s })
}
+
+ private def collectShuffles(plan: SparkPlan): Seq[ShuffleExchangeLike] = {
+ // here we skip collecting shuffle operators that are not associated with SMJ
+ collect(plan) {
+ case s: SortMergeJoinExec => s
+ case s: SortMergeJoinExecTransformer => s
+ }.flatMap(
+ smj =>
+ collect(smj) {
+ case s: ShuffleExchangeExec => s
+ case s: ColumnarShuffleExchangeExec => s
+ })
+ }
+
+ private def collectAllShuffles(plan: SparkPlan): Seq[ColumnarShuffleExchangeExec] = {
+ collect(plan) { case s: ColumnarShuffleExchangeExec => s }
+ }
+
private def collectScans(plan: SparkPlan): Seq[BatchScanExec] = {
collect(plan) { case s: BatchScanExec => s }
}
+ private def selectWithMergeJoinHint(t1: String, t2: String): String = {
+ s"SELECT /*+ MERGE($t1, $t2) */ "
+ }
+
+ private def createJoinTestDF(
+ keys: Seq[(String, String)],
+ extraColumns: Seq[String] = Nil,
+ joinType: String = ""): DataFrame = {
+ val extraColList = if (extraColumns.isEmpty) "" else
extraColumns.mkString(", ", ", ", "")
+ sql(s"""
+ |${selectWithMergeJoinHint("i", "p")}
+ |id, name, i.price as purchase_price, p.price as sale_price $extraColList
+ |FROM testcat.ns.$items i $joinType JOIN testcat.ns.$purchases p
+ |ON ${keys.map(k => s"i.${k._1} = p.${k._2}").mkString(" AND ")}
+ |ORDER BY id, purchase_price, sale_price $extraColList
+ |""".stripMargin)
+ }
+
private val customers: String = "customers"
private val customersColumns: Array[Column] = Array(
Column.create("customer_name", StringType),
@@ -912,13 +960,23 @@ class GlutenKeyGroupedPartitioningSuite
s"(1, 42.0, cast('2020-01-01' as timestamp)), " +
s"(2, 11.0, cast('2020-01-01' as timestamp))")
- val df = sql(
- "SELECT id, name, i.price as purchase_price, p.price as sale_price " +
- s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
- "ON i.id = p.item_id AND i.arrive_time = p.time ORDER BY id,
purchase_price, sale_price")
+ Seq(true, false).foreach {
+ pushDownValues =>
+ withSQLConf(SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString) {
+ val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" -> "time"))
+ val shuffles = collectColumnarShuffleExchangeExec(df.queryExecution.executedPlan)
+ if (pushDownValues) {
+ assert(shuffles.isEmpty, "should not add shuffle when partition values mismatch")
+ } else {
+ assert(
+ shuffles.nonEmpty,
+ "should add shuffle when partition values mismatch, and " +
+ "pushing down partition values is not enabled")
+ }
- val shuffles = collectColumnarShuffleExchangeExec(df.queryExecution.executedPlan)
- assert(shuffles.nonEmpty, "should add shuffle when partition keys mismatch")
+ checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(2, "bb", 10.0, 11.0)))
+ }
+ }
}
testGluten("data source partitioning + dynamic partition filtering") {
@@ -972,4 +1030,796 @@ class GlutenKeyGroupedPartitioningSuite
}
}
}
+
+ testGluten(
+ "SPARK-41471: shuffle one side: only one side reports partitioning with
two identity") {
+ val items_partitions = Array(identity("id"), identity("arrive_time"))
+ createTable(items, itemsColumns, items_partitions)
+
+ sql(
+ s"INSERT INTO testcat.ns.$items VALUES " +
+ "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+ "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+ "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+ createTable(purchases, purchasesColumns, Array.empty)
+ sql(
+ s"INSERT INTO testcat.ns.$purchases VALUES " +
+ "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+ "(3, 19.5, cast('2020-02-01' as timestamp))")
+
+ Seq(true, false).foreach {
+ shuffle =>
+ withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> shuffle.toString) {
+ val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" -> "time"))
+ val shuffles = collectShuffles(df.queryExecution.executedPlan)
+ if (shuffle) {
+ assert(shuffles.size == 1, "only shuffle one side not report
partitioning")
+ } else {
+ assert(
+ shuffles.size == 2,
+ "should add two side shuffle when bucketing shuffle one side" +
+ " is not enabled")
+ }
+
+ checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0)))
+ }
+ }
+ }
+
+ testGluten("SPARK-41471: shuffle one side: only one side reports
partitioning") {
+ val items_partitions = Array(identity("id"))
+ createTable(items, itemsColumns, items_partitions)
+
+ sql(
+ s"INSERT INTO testcat.ns.$items VALUES " +
+ "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+ "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+ "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+ createTable(purchases, purchasesColumns, Array.empty)
+ sql(
+ s"INSERT INTO testcat.ns.$purchases VALUES " +
+ "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+ "(3, 19.5, cast('2020-02-01' as timestamp))")
+
+ Seq(true, false).foreach {
+ shuffle =>
+ withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> shuffle.toString) {
+ val df = createJoinTestDF(Seq("id" -> "item_id"))
+ val shuffles = collectShuffles(df.queryExecution.executedPlan)
+ if (shuffle) {
+ assert(shuffles.size == 1, "only shuffle one side not report partitioning")
+ } else {
+ assert(
+ shuffles.size == 2,
+ "should add two side shuffle when bucketing shuffle one side" +
+ " is not enabled")
+ }
+
+ checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0,
19.5)))
+ }
+ }
+ }
+
+ testGluten("SPARK-41471: shuffle one side: shuffle side has more partition
value") {
+ val items_partitions = Array(identity("id"))
+ createTable(items, itemsColumns, items_partitions)
+
+ sql(
+ s"INSERT INTO testcat.ns.$items VALUES " +
+ "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+ "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+ "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+ createTable(purchases, purchasesColumns, Array.empty)
+ sql(
+ s"INSERT INTO testcat.ns.$purchases VALUES " +
+ "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+ "(3, 19.5, cast('2020-02-01' as timestamp)), " +
+ "(5, 26.0, cast('2023-01-01' as timestamp)), " +
+ "(6, 50.0, cast('2023-02-01' as timestamp))")
+
+ Seq(true, false).foreach {
+ shuffle =>
+ withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> shuffle.toString) {
+ Seq("", "LEFT OUTER", "RIGHT OUTER", "FULL OUTER").foreach {
+ joinType =>
+ val df = createJoinTestDF(Seq("id" -> "item_id"), joinType =
joinType)
+ val shuffles = collectShuffles(df.queryExecution.executedPlan)
+ if (shuffle) {
+ assert(shuffles.size == 1, "only shuffle one side not report partitioning")
+ } else {
+ assert(
+ shuffles.size == 2,
+ "should add two side shuffle when bucketing shuffle one " +
+ "side is not enabled")
+ }
+ joinType match {
+ case "" =>
+ checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb",
10.0, 19.5)))
+ case "LEFT OUTER" =>
+ checkAnswer(
+ df,
+ Seq(
+ Row(1, "aa", 40.0, 42.0),
+ Row(3, "bb", 10.0, 19.5),
+ Row(4, "cc", 15.5, null)))
+ case "RIGHT OUTER" =>
+ checkAnswer(
+ df,
+ Seq(
+ Row(null, null, null, 26.0),
+ Row(null, null, null, 50.0),
+ Row(1, "aa", 40.0, 42.0),
+ Row(3, "bb", 10.0, 19.5)))
+ case "FULL OUTER" =>
+ checkAnswer(
+ df,
+ Seq(
+ Row(null, null, null, 26.0),
+ Row(null, null, null, 50.0),
+ Row(1, "aa", 40.0, 42.0),
+ Row(3, "bb", 10.0, 19.5),
+ Row(4, "cc", 15.5, null)))
+ }
+ }
+ }
+ }
+ }
+
+ testGluten("SPARK-41471: shuffle one side: partitioning with transform") {
+ val items_partitions = Array(years("arrive_time"))
+ createTable(items, itemsColumns, items_partitions)
+
+ sql(
+ s"INSERT INTO testcat.ns.$items VALUES " +
+ "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+ "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+ "(4, 'cc', 15.5, cast('2021-02-01' as timestamp))")
+
+ createTable(purchases, purchasesColumns, Array.empty)
+ sql(
+ s"INSERT INTO testcat.ns.$purchases VALUES " +
+ "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+ "(3, 19.5, cast('2021-02-01' as timestamp))")
+
+ Seq(true, false).foreach {
+ shuffle =>
+ withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> shuffle.toString) {
+ val df = createJoinTestDF(Seq("arrive_time" -> "time"))
+ val shuffles = collectShuffles(df.queryExecution.executedPlan)
+ if (shuffle) {
+ assert(shuffles.size == 1, "partitioning with transform should
trigger SPJ")
+ } else {
+ assert(
+ shuffles.size == 2,
+ "should add two side shuffle when bucketing shuffle one side" +
+ " is not enabled")
+ }
+
+ checkAnswer(
+ df,
+ Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 42.0), Row(4,
"cc", 15.5, 19.5)))
+ }
+ }
+ }
+
+ testGluten(
+ "SPARK-44647: SPJ: test join key is subset of cluster key " +
+ "with push values and partially-clustered") {
+ val table1 = "tab1e1"
+ val table2 = "table2"
+ val partition = Array(identity("id"), identity("data"))
+ createTable(table1, columns, partition)
+ sql(
+ s"INSERT INTO testcat.ns.$table1 VALUES " +
+ "(1, 'aa', cast('2020-01-01' as timestamp)), " +
+ "(2, 'bb', cast('2020-01-01' as timestamp)), " +
+ "(2, 'cc', cast('2020-01-01' as timestamp)), " +
+ "(3, 'dd', cast('2020-01-01' as timestamp)), " +
+ "(3, 'dd', cast('2020-01-01' as timestamp)), " +
+ "(3, 'ee', cast('2020-01-01' as timestamp)), " +
+ "(3, 'ee', cast('2020-01-01' as timestamp))")
+
+ createTable(table2, columns, partition)
+ sql(
+ s"INSERT INTO testcat.ns.$table2 VALUES " +
+ "(4, 'zz', cast('2020-01-01' as timestamp)), " +
+ "(4, 'zz', cast('2020-01-01' as timestamp)), " +
+ "(3, 'yy', cast('2020-01-01' as timestamp)), " +
+ "(3, 'yy', cast('2020-01-01' as timestamp)), " +
+ "(3, 'xx', cast('2020-01-01' as timestamp)), " +
+ "(3, 'xx', cast('2020-01-01' as timestamp)), " +
+ "(2, 'ww', cast('2020-01-01' as timestamp))")
+
+ Seq(true, false).foreach {
+ pushDownValues =>
+ Seq(true, false).foreach {
+ filter =>
+ Seq(true, false).foreach {
+ partiallyClustered =>
+ Seq(true, false).foreach {
+ allowJoinKeysSubsetOfPartitionKeys =>
+ withSQLConf(
+ SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+ SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString,
+ SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key ->
+ partiallyClustered.toString,
+ SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED.key -> filter.toString,
+ SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key ->
+ allowJoinKeysSubsetOfPartitionKeys.toString
+ ) {
+ val df = sql(s"""
+ |${selectWithMergeJoinHint("t1", "t2")}
+ |t1.id AS id, t1.data AS t1data, t2.data AS t2data
+ |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2
+ |ON t1.id = t2.id ORDER BY t1.id, t1data, t2data
+ |""".stripMargin)
+ val shuffles =
+ collectColumnarShuffleExchangeExec(df.queryExecution.executedPlan)
+ if (allowJoinKeysSubsetOfPartitionKeys) {
+ assert(shuffles.isEmpty, "SPJ should be triggered")
+ } else {
+ assert(shuffles.nonEmpty, "SPJ should not be
triggered")
+ }
+
+ val scannedPartitions = collectScans(df.queryExecution.executedPlan)
+ .map(_.inputRDD.partitions.length)
+ (allowJoinKeysSubsetOfPartitionKeys, partiallyClustered, filter) match {
+ // SPJ, partially-clustered, with filter
+ case (true, true, true) => assert(scannedPartitions == Seq(6, 6))
+
+ // SPJ, partially-clustered, no filter
+ case (true, true, false) => assert(scannedPartitions == Seq(8, 8))
+
+ // SPJ and not partially-clustered, with filter
+ case (true, false, true) => assert(scannedPartitions == Seq(2, 2))
+
+ // SPJ and not partially-clustered, no filter
+ case (true, false, false) => assert(scannedPartitions == Seq(4, 4))
+
+ // No SPJ
+ case _ => assert(scannedPartitions == Seq(5, 4))
+ }
+
+ checkAnswer(
+ df,
+ Seq(
+ Row(2, "bb", "ww"),
+ Row(2, "cc", "ww"),
+ Row(3, "dd", "xx"),
+ Row(3, "dd", "xx"),
+ Row(3, "dd", "xx"),
+ Row(3, "dd", "xx"),
+ Row(3, "dd", "yy"),
+ Row(3, "dd", "yy"),
+ Row(3, "dd", "yy"),
+ Row(3, "dd", "yy"),
+ Row(3, "ee", "xx"),
+ Row(3, "ee", "xx"),
+ Row(3, "ee", "xx"),
+ Row(3, "ee", "xx"),
+ Row(3, "ee", "yy"),
+ Row(3, "ee", "yy"),
+ Row(3, "ee", "yy"),
+ Row(3, "ee", "yy")
+ )
+ )
+ }
+ }
+ }
+ }
+ }
+ }
+
+ testGluten("SPARK-44647: test join key is the second cluster key") {
+ val table1 = "tab1e1"
+ val table2 = "table2"
+ val partition = Array(identity("id"), identity("data"))
+ createTable(table1, columns, partition)
+ sql(
+ s"INSERT INTO testcat.ns.$table1 VALUES " +
+ "(1, 'aa', cast('2020-01-01' as timestamp)), " +
+ "(2, 'bb', cast('2020-01-02' as timestamp)), " +
+ "(3, 'cc', cast('2020-01-03' as timestamp))")
+
+ createTable(table2, columns, partition)
+ sql(
+ s"INSERT INTO testcat.ns.$table2 VALUES " +
+ "(4, 'aa', cast('2020-01-01' as timestamp)), " +
+ "(5, 'bb', cast('2020-01-02' as timestamp)), " +
+ "(6, 'cc', cast('2020-01-03' as timestamp))")
+
+ Seq(true, false).foreach {
+ pushDownValues =>
+ Seq(true, false).foreach {
+ partiallyClustered =>
+ Seq(true, false).foreach {
+ allowJoinKeysSubsetOfPartitionKeys =>
+ withSQLConf(
+ SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+ SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key ->
+ pushDownValues.toString,
+ SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key ->
+ partiallyClustered.toString,
+ SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key ->
+ allowJoinKeysSubsetOfPartitionKeys.toString
+ ) {
+
+ val df = sql(s"""
+ |${selectWithMergeJoinHint("t1", "t2")}
+ |t1.id AS t1id, t2.id as t2id, t1.data AS data
+ |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2
+ |ON t1.data = t2.data
+ |ORDER BY t1id, t1id, data
+ |""".stripMargin)
+ checkAnswer(df, Seq(Row(1, 4, "aa"), Row(2, 5, "bb"), Row(3,
6, "cc")))
+
+ val shuffles = collectColumnarShuffleExchangeExec(df.queryExecution.executedPlan)
+ if (allowJoinKeysSubsetOfPartitionKeys) {
+ assert(shuffles.isEmpty, "SPJ should be triggered")
+ } else {
+ assert(shuffles.nonEmpty, "SPJ should not be triggered")
+ }
+
+ val scans = collectScans(df.queryExecution.executedPlan)
+ .map(_.inputRDD.partitions.length)
+ (pushDownValues, allowJoinKeysSubsetOfPartitionKeys, partiallyClustered) match {
+ // SPJ and partially-clustered
+ case (true, true, true) => assert(scans == Seq(3, 3))
+ // non-SPJ or SPJ/partially-clustered
+ case _ => assert(scans == Seq(3, 3))
+ }
+ }
+ }
+ }
+ }
+ }
+
+ testGluten("SPARK-44647: test join key is the second partition key and a
transform") {
+ val items_partitions = Array(bucket(8, "id"), days("arrive_time"))
+ createTable(items, itemsColumns, items_partitions)
+ sql(
+ s"INSERT INTO testcat.ns.$items VALUES " +
+ s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+ s"(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " +
+ s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+ s"(2, 'bb', 10.5, cast('2020-01-01' as timestamp)), " +
+ s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+ val purchases_partitions = Array(bucket(8, "item_id"), days("time"))
+ createTable(purchases, purchasesColumns, purchases_partitions)
+ sql(
+ s"INSERT INTO testcat.ns.$purchases VALUES " +
+ s"(1, 42.0, cast('2020-01-01' as timestamp)), " +
+ s"(1, 44.0, cast('2020-01-15' as timestamp)), " +
+ s"(1, 45.0, cast('2020-01-15' as timestamp)), " +
+ s"(2, 11.0, cast('2020-01-01' as timestamp)), " +
+ s"(3, 19.5, cast('2020-02-01' as timestamp))")
+
+ Seq(true, false).foreach {
+ pushDownValues =>
+ Seq(true, false).foreach {
+ partiallyClustered =>
+ Seq(true, false).foreach {
+ allowJoinKeysSubsetOfPartitionKeys =>
+ withSQLConf(
+ SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+ SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString,
+ SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key ->
+ partiallyClustered.toString,
+ SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key ->
+ allowJoinKeysSubsetOfPartitionKeys.toString
+ ) {
+ val df =
+ createJoinTestDF(Seq("arrive_time" -> "time"),
extraColumns = Seq("p.item_id"))
+ // Currently SPJ for case where join key not same as partition key
+ // only supported when push-part-values enabled
+ val shuffles = collectColumnarShuffleExchangeExec(df.queryExecution.executedPlan)
+ if (allowJoinKeysSubsetOfPartitionKeys) {
+ assert(shuffles.isEmpty, "SPJ should be triggered")
+ } else {
+ assert(shuffles.nonEmpty, "SPJ should not be triggered")
+ }
+
+ val scans = collectScans(df.queryExecution.executedPlan)
+ .map(_.inputRDD.partitions.length)
+ (allowJoinKeysSubsetOfPartitionKeys, partiallyClustered) match {
+ // SPJ and partially-clustered
+ case (true, true) => assert(scans == Seq(5, 5))
+ // SPJ and not partially-clustered
+ case (true, false) => assert(scans == Seq(3, 3))
+ // No SPJ
+ case _ => assert(scans == Seq(4, 4))
+ }
+
+ checkAnswer(
+ df,
+ Seq(
+ Row(1, "aa", 40.0, 11.0, 2),
+ Row(1, "aa", 40.0, 42.0, 1),
+ Row(1, "aa", 41.0, 44.0, 1),
+ Row(1, "aa", 41.0, 45.0, 1),
+ Row(2, "bb", 10.0, 11.0, 2),
+ Row(2, "bb", 10.0, 42.0, 1),
+ Row(2, "bb", 10.5, 11.0, 2),
+ Row(2, "bb", 10.5, 42.0, 1),
+ Row(3, "cc", 15.5, 19.5, 3)
+ )
+ )
+ }
+ }
+ }
+ }
+ }
+
+ testGluten("SPARK-44647: shuffle one side and join keys are less than
partition keys") {
+ val items_partitions = Array(identity("id"), identity("name"))
+ createTable(items, itemsColumns, items_partitions)
+
+ sql(
+ s"INSERT INTO testcat.ns.$items VALUES " +
+ "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+ "(1, 'aa', 30.0, cast('2020-01-02' as timestamp)), " +
+ "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+ "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+ createTable(purchases, purchasesColumns, Array.empty)
+ sql(
+ s"INSERT INTO testcat.ns.$purchases VALUES " +
+ "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+ "(1, 89.0, cast('2020-01-03' as timestamp)), " +
+ "(3, 19.5, cast('2020-02-01' as timestamp)), " +
+ "(5, 26.0, cast('2023-01-01' as timestamp)), " +
+ "(6, 50.0, cast('2023-02-01' as timestamp))")
+
+ Seq(true, false).foreach {
+ pushdownValues =>
+ withSQLConf(
+ SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
+ SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString,
+ SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false",
+ SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true"
+ ) {
+ val df = createJoinTestDF(Seq("id" -> "item_id"))
+ val shuffles = collectShuffles(df.queryExecution.executedPlan)
+ assert(shuffles.size == 1, "SPJ should be triggered")
+ checkAnswer(
+ df,
+ Seq(
+ Row(1, "aa", 30.0, 42.0),
+ Row(1, "aa", 30.0, 89.0),
+ Row(1, "aa", 40.0, 42.0),
+ Row(1, "aa", 40.0, 89.0),
+ Row(3, "bb", 10.0, 19.5)))
+ }
+ }
+ }
+
+ testGluten(
+ "SPARK-47094: Compatible buckets does not support SPJ with " +
+ "push-down values or partially-clustered") {
+ val table1 = "tab1e1"
+ val table2 = "table2"
+
+ val partition1 = Array(bucket(4, "store_id"), bucket(2, "dept_id"))
+ val partition2 = Array(bucket(2, "store_id"), bucket(2, "dept_id"))
+
+ createTable(table1, columns2, partition1)
+ sql(
+ s"INSERT INTO testcat.ns.$table1 VALUES " +
+ "(0, 0, 'aa'), " +
+ "(1, 1, 'bb'), " +
+ "(2, 2, 'cc')"
+ )
+
+ createTable(table2, columns2, partition2)
+ sql(
+ s"INSERT INTO testcat.ns.$table2 VALUES " +
+ "(0, 0, 'aa'), " +
+ "(1, 1, 'bb'), " +
+ "(2, 2, 'cc')"
+ )
+
+ Seq(true, false).foreach {
+ allowPushDown =>
+ Seq(true, false).foreach {
+ partiallyClustered =>
+ withSQLConf(
+ SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+ SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> allowPushDown.toString,
+ SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key ->
+ partiallyClustered.toString,
+ SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true",
+ SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true"
+ ) {
+ val df = sql(s"""
+ |${selectWithMergeJoinHint("t1", "t2")}
+ |t1.store_id, t1.store_id, t1.dept_id, t2.dept_id, t1.data, t2.data
+ |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2
+ |ON t1.store_id = t2.store_id AND t1.dept_id = t2.dept_id
+ |ORDER BY t1.store_id, t1.dept_id, t1.data, t2.data
+ |""".stripMargin)
+
+ val shuffles = collectColumnarShuffleExchangeExec(df.queryExecution.executedPlan)
+ val scans =
+ collectScans(df.queryExecution.executedPlan).map(_.inputRDD.partitions.length)
+
+ (allowPushDown, partiallyClustered) match {
+ case (true, false) =>
+ assert(shuffles.isEmpty, "SPJ should be triggered")
+ assert(scans == Seq(2, 2))
+ case (_, _) =>
+ assert(shuffles.nonEmpty, "SPJ should not be triggered")
+ assert(scans == Seq(3, 2))
+ }
+
+ checkAnswer(
+ df,
+ Seq(
+ Row(0, 0, 0, 0, "aa", "aa"),
+ Row(1, 1, 1, 1, "bb", "bb"),
+ Row(2, 2, 2, 2, "cc", "cc")
+ ))
+ }
+ }
+ }
+ }
+
+ testGluten(
+ "SPARK-47094: SPJ: Does not trigger when incompatible number of buckets on
both side") {
+ val table1 = "tab1e1"
+ val table2 = "table2"
+
+ Seq(
+ (2, 3),
+ (3, 4)
+ ).foreach {
+ case (table1buckets1, table2buckets1) =>
+ catalog.clearTables()
+
+ val partition1 = Array(bucket(table1buckets1, "store_id"))
+ val partition2 = Array(bucket(table2buckets1, "store_id"))
+
+ Seq((table1, partition1), (table2, partition2)).foreach {
+ case (tab, part) =>
+ createTable(tab, columns2, part)
+ val insertStr = s"INSERT INTO testcat.ns.$tab VALUES " +
+ "(0, 0, 'aa'), " +
+ "(1, 0, 'ab'), " + // duplicate partition key
+ "(2, 2, 'ac'), " +
+ "(3, 3, 'ad'), " +
+ "(4, 2, 'bc') "
+
+ sql(insertStr)
+ }
+
+ Seq(true, false).foreach {
+ allowJoinKeysSubsetOfPartitionKeys =>
+ withSQLConf(
+ SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+ SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+ SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false",
+ SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key ->
+ allowJoinKeysSubsetOfPartitionKeys.toString,
+ SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true"
+ ) {
+ val df = sql(s"""
+ |${selectWithMergeJoinHint("t1", "t2")}
+ |t1.store_id, t1.dept_id, t1.data, t2.data
+ |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2
+ |ON t1.store_id = t2.store_id AND t1.dept_id = t2.dept_id
+ |""".stripMargin)
+
+ val shuffles = collectColumnarShuffleExchangeExec(df.queryExecution.executedPlan)
+ assert(shuffles.nonEmpty, "SPJ should not be triggered")
+ }
+ }
+ }
+ }
+
+ testGluten("SPARK-48655: order by on partition keys should not introduce
additional shuffle") {
+ val items_partitions = Array(identity("price"), identity("id"))
+ createTable(items, itemsColumns, items_partitions)
+ sql(
+ s"INSERT INTO testcat.ns.$items VALUES " +
+ s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+ s"(1, 'aa', 41.0, cast('2020-01-02' as timestamp)), " +
+ s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+ s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp)), " +
+ s"(null, 'cc', 15.5, cast('2020-02-01' as timestamp)), " +
+ s"(3, 'cc', null, cast('2020-02-01' as timestamp))")
+
+ Seq(true, false).foreach {
+ sortingEnabled =>
+ withSQLConf(SQLConf.V2_BUCKETING_SORTING_ENABLED.key -> sortingEnabled.toString) {
+
+ def verifyShuffle(cmd: String, answer: Seq[Row]): Unit = {
+ val df = sql(cmd)
+ if (sortingEnabled) {
+ assert(
+ collectAllShuffles(df.queryExecution.executedPlan).isEmpty,
+ "should contain no shuffle when sorting by partition values")
+ } else {
+ assert(
+ collectAllShuffles(df.queryExecution.executedPlan).size == 1,
+ "should contain one shuffle when optimization is disabled")
+ }
+ checkAnswer(df, answer)
+ }: Unit
+
+ verifyShuffle(
+ s"SELECT price, id FROM testcat.ns.$items ORDER BY price ASC, id
ASC",
+ Seq(
+ Row(null, 3),
+ Row(10.0, 2),
+ Row(15.5, null),
+ Row(15.5, 3),
+ Row(40.0, 1),
+ Row(41.0, 1)))
+
+ verifyShuffle(
+ s"SELECT price, id FROM testcat.ns.$items " +
+ s"ORDER BY price ASC NULLS LAST, id ASC NULLS LAST",
+ Seq(
+ Row(10.0, 2),
+ Row(15.5, 3),
+ Row(15.5, null),
+ Row(40.0, 1),
+ Row(41.0, 1),
+ Row(null, 3))
+ )
+
+ verifyShuffle(
+ s"SELECT price, id FROM testcat.ns.$items ORDER BY price DESC, id
ASC",
+ Seq(
+ Row(41.0, 1),
+ Row(40.0, 1),
+ Row(15.5, null),
+ Row(15.5, 3),
+ Row(10.0, 2),
+ Row(null, 3))
+ )
+
+ verifyShuffle(
+ s"SELECT price, id FROM testcat.ns.$items ORDER BY price DESC, id
DESC",
+ Seq(
+ Row(41.0, 1),
+ Row(40.0, 1),
+ Row(15.5, 3),
+ Row(15.5, null),
+ Row(10.0, 2),
+ Row(null, 3))
+ )
+
+ verifyShuffle(
+ s"SELECT price, id FROM testcat.ns.$items " +
+ s"ORDER BY price DESC NULLS FIRST, id DESC NULLS FIRST",
+ Seq(
+ Row(null, 3),
+ Row(41.0, 1),
+ Row(40.0, 1),
+ Row(15.5, null),
+ Row(15.5, 3),
+ Row(10.0, 2))
+ );
+ }
+ }
+ }
+
+ testGluten("SPARK-48012: one-side shuffle with partition transforms") {
+ val items_partitions = Array(bucket(2, "id"), identity("arrive_time"))
+ val items_partitions2 = Array(identity("arrive_time"), bucket(2, "id"))
+
+ Seq(items_partitions, items_partitions2).foreach {
+ partition =>
+ catalog.clearTables()
+
+ createTable(items, itemsColumns, partition)
+ sql(
+ s"INSERT INTO testcat.ns.$items VALUES " +
+ "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+ "(1, 'bb', 30.0, cast('2020-01-01' as timestamp)), " +
+ "(1, 'cc', 30.0, cast('2020-01-02' as timestamp)), " +
+ "(3, 'dd', 10.0, cast('2020-01-01' as timestamp)), " +
+ "(4, 'ee', 15.5, cast('2020-02-01' as timestamp)), " +
+ "(5, 'ff', 32.1, cast('2020-03-01' as timestamp))")
+
+ createTable(purchases, purchasesColumns, Array.empty)
+ sql(
+ s"INSERT INTO testcat.ns.$purchases VALUES " +
+ "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+ "(2, 10.7, cast('2020-01-01' as timestamp))," +
+ "(3, 19.5, cast('2020-02-01' as timestamp))," +
+ "(4, 56.5, cast('2020-02-01' as timestamp))")
+
+ withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true") {
+ val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" ->
"time"))
+ val shuffles = collectShuffles(df.queryExecution.executedPlan)
+ assert(shuffles.size == 1, "only shuffle side that does not report
partitioning")
+
+ checkAnswer(
+ df,
+ Seq(Row(1, "bb", 30.0, 42.0), Row(1, "aa", 40.0, 42.0), Row(4,
"ee", 15.5, 56.5)))
+ }
+ }
+ }
+
+ testGluten("SPARK-48012: one-side shuffle with partition transforms and
pushdown values") {
+ val items_partitions = Array(bucket(2, "id"), identity("arrive_time"))
+ createTable(items, itemsColumns, items_partitions)
+
+ sql(
+ s"INSERT INTO testcat.ns.$items VALUES " +
+ "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+ "(1, 'bb', 30.0, cast('2020-01-01' as timestamp)), " +
+ "(1, 'cc', 30.0, cast('2020-01-02' as timestamp))")
+
+ createTable(purchases, purchasesColumns, Array.empty)
+ sql(
+ s"INSERT INTO testcat.ns.$purchases VALUES " +
+ "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+ "(2, 10.7, cast('2020-01-01' as timestamp))")
+
+ Seq(true, false).foreach {
+ pushDown =>
+ {
+ withSQLConf(
+ SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
+ SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key ->
+ pushDown.toString) {
+ val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" ->
"time"))
+ val shuffles = collectShuffles(df.queryExecution.executedPlan)
+ assert(shuffles.size == 1, "only shuffle side that does not report
partitioning")
+
+ checkAnswer(df, Seq(Row(1, "bb", 30.0, 42.0), Row(1, "aa", 40.0,
42.0)))
+ }
+ }
+ }
+ }
+
+ testGluten(
+ "SPARK-48012: one-side shuffle with partition transforms " +
+ "with fewer join keys than partition kes") {
+ val items_partitions = Array(bucket(2, "id"), identity("name"))
+ createTable(items, itemsColumns, items_partitions)
+
+ sql(
+ s"INSERT INTO testcat.ns.$items VALUES " +
+ "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+ "(1, 'aa', 30.0, cast('2020-01-02' as timestamp)), " +
+ "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+ "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+ createTable(purchases, purchasesColumns, Array.empty)
+ sql(
+ s"INSERT INTO testcat.ns.$purchases VALUES " +
+ "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+ "(1, 89.0, cast('2020-01-03' as timestamp)), " +
+ "(3, 19.5, cast('2020-02-01' as timestamp)), " +
+ "(5, 26.0, cast('2023-01-01' as timestamp)), " +
+ "(6, 50.0, cast('2023-02-01' as timestamp))")
+
+ withSQLConf(
+ SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+ SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
+ SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+ SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false",
+ SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true"
+ ) {
+ val df = createJoinTestDF(Seq("id" -> "item_id"))
+ val shuffles = collectShuffles(df.queryExecution.executedPlan)
+ assert(shuffles.size == 1, "SPJ should be triggered")
+ checkAnswer(
+ df,
+ Seq(
+ Row(1, "aa", 30.0, 42.0),
+ Row(1, "aa", 30.0, 89.0),
+ Row(1, "aa", 40.0, 42.0),
+ Row(1, "aa", 40.0, 89.0),
+ Row(3, "bb", 10.0, 19.5)))
+ }
+ }
+
}
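
All of the new cases above follow one shape: create two partitioned catalog
tables, toggle the relevant V2 bucketing confs, then assert on the shuffles and
scanned partitions in the executed plan. A distilled sketch of that loop (table
setup elided; `expectedRows` is a placeholder):

    Seq(true, false).foreach { pushDownValues =>
      withSQLConf(
        SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString) {
        val df = createJoinTestDF(Seq("id" -> "item_id"))
        val shuffles = collectShuffles(df.queryExecution.executedPlan)
        // SPJ should remove the shuffle only when partition values are pushed down.
        if (pushDownValues) assert(shuffles.isEmpty) else assert(shuffles.nonEmpty)
        checkAnswer(df, expectedRows)
      }
    }
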
diff --git a/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala b/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala
index 9077fe5abc..247394ba91 100644
--- a/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala
+++ b/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, KeyGroupedPartitioning, Partitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, KeyGroupedPartitioning, KeyGroupedShuffleSpec, Partitioning}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, InternalRowComparableWrapper, TimestampFormatter}
@@ -47,7 +47,7 @@ import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, Sca
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetFilters}
-import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanExecBase}
+import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, BatchScanExecShim, DataSourceV2ScanExecBase}
import org.apache.spark.sql.execution.datasources.v2.text.TextScan
import org.apache.spark.sql.execution.datasources.v2.utils.CatalogUtil
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ShuffleExchangeLike}
@@ -482,10 +482,9 @@ class Spark40Shims extends SparkShims {
applyPartialClustering: Boolean,
replicatePartitions: Boolean,
joinKeyPositions: Option[Seq[Int]] = None): Seq[Seq[InputPartition]] = {
+ val original = batchScan.asInstanceOf[BatchScanExecShim]
scan match {
case _ if keyGroupedPartitioning.isDefined =>
- var finalPartitions = filteredPartitions
-
outputPartitioning match {
case p: KeyGroupedPartitioning =>
assert(keyGroupedPartitioning.isDefined)
@@ -516,8 +515,20 @@ class Spark40Shims extends SparkShims {
}
// Also re-group the partitions if we are reducing compatible partition expressions
- // TODO: Respect Reducer settings?
- val finalGroupedPartitions = groupedPartitions
+ val finalGroupedPartitions = original.reducers match {
+ case Some(reducers) =>
+ val result = groupedPartitions
+ .groupBy {
+ case (row, _) =>
+ KeyGroupedShuffleSpec.reducePartitionValue(row, partExpressions, reducers)
+ }
+ .map { case (wrapper, splits) => (wrapper.row, splits.flatMap(_._2)) }
+ .toSeq
+ val rowOrdering =
+ RowOrdering.createNaturalAscendingOrdering(partExpressions.map(_.dataType))
+ result.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1))
+ case _ => groupedPartitions
+ }
// When partially clustered, the input partitions are not grouped by partition
// values. Here we'll need to check `commonPartitionValues` and decide how to group
@@ -587,9 +598,8 @@ class Spark40Shims extends SparkShims {
}
}
- case _ =>
+ case _ => filteredPartitions
}
- finalPartitions
case _ =>
filteredPartitions
}
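
The shim change mirrors Spark 4.0's own planner: when the BatchScanExecShim
carries reducers, partition values from the finer-bucketed side are reduced
onto the coarser side before grouping, so both scans report compatible
partitioning. A toy illustration of the idea (hypothetical modulo reducer,
standing in for KeyGroupedShuffleSpec.reducePartitionValue with bucket(4)
joined to bucket(2)):

    // Reducing bucket ids 0..3 onto 2 buckets co-groups 0 with 2 and 1 with 3,
    // which is the same regrouping finalGroupedPartitions performs on rows.
    val reduce: Int => Int = id => id % 2
    val grouped = (0 to 3).groupBy(reduce)
    // grouped == Map(0 -> Vector(0, 2), 1 -> Vector(1, 3))
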
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]