advancedxy commented on code in PR #45267:
URL: https://github.com/apache/spark/pull/45267#discussion_r1505905928
##########
sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala:
##########
@@ -1310,6 +1314,312 @@ class KeyGroupedPartitioningSuite extends
DistributionAndOrderingSuiteBase {
}
}
+ test("SPARK-47094: Support compatible buckets") {
+ val table1 = "tab1e1"
+ val table2 = "table2"
+
+ Seq(
+ ((2, 4), (4, 2)),
+ ((4, 2), (2, 4)),
+ ((2, 2), (4, 6)),
+ ((6, 2), (2, 2))).foreach {
+ case ((table1buckets1, table1buckets2), (table2buckets1,
table2buckets2)) =>
+ catalog.clearTables()
+
+ val partition1 = Array(bucket(table1buckets1, "store_id"),
+ bucket(table1buckets2, "dept_id"))
+ val partition2 = Array(bucket(table2buckets1, "store_id"),
+ bucket(table2buckets2, "dept_id"))
+
+ Seq((table1, partition1), (table2, partition2)).foreach { case (tab,
part) =>
+ createTable(tab, schema2, part)
+ val insertStr = s"INSERT INTO testcat.ns.$tab VALUES " +
+ "(0, 0, 'aa'), " +
+ "(0, 0, 'ab'), " + // duplicate partition key
+ "(0, 1, 'ac'), " +
+ "(0, 2, 'ad'), " +
+ "(0, 3, 'ae'), " +
+ "(0, 4, 'af'), " +
+ "(0, 5, 'ag'), " +
+ "(1, 0, 'ah'), " +
+ "(1, 0, 'ai'), " + // duplicate partition key
+ "(1, 1, 'aj'), " +
+ "(1, 2, 'ak'), " +
+ "(1, 3, 'al'), " +
+ "(1, 4, 'am'), " +
+ "(1, 5, 'an'), " +
+ "(2, 0, 'ao'), " +
+ "(2, 0, 'ap'), " + // duplicate partition key
+ "(2, 1, 'aq'), " +
+ "(2, 2, 'ar'), " +
+ "(2, 3, 'as'), " +
+ "(2, 4, 'at'), " +
+ "(2, 5, 'au'), " +
+ "(3, 0, 'av'), " +
+ "(3, 0, 'aw'), " + // duplicate partition key
+ "(3, 1, 'ax'), " +
+ "(3, 2, 'ay'), " +
+ "(3, 3, 'az'), " +
+ "(3, 4, 'ba'), " +
+ "(3, 5, 'bb'), " +
+ "(4, 0, 'bc'), " +
+ "(4, 0, 'bd'), " + // duplicate partition key
+ "(4, 1, 'be'), " +
+ "(4, 2, 'bf'), " +
+ "(4, 3, 'bg'), " +
+ "(4, 4, 'bh'), " +
+ "(4, 5, 'bi'), " +
+ "(5, 0, 'bj'), " +
+ "(5, 0, 'bk'), " + // duplicate partition key
+ "(5, 1, 'bl'), " +
+ "(5, 2, 'bm'), " +
+ "(5, 3, 'bn'), " +
+ "(5, 4, 'bo'), " +
+ "(5, 5, 'bp')"
+
+ // additional unmatched partitions to test push down
+ val finalStr = if (tab == table1) {
+ insertStr ++ ", (8, 0, 'xa'), (8, 8, 'xx')"
+ } else {
+ insertStr ++ ", (9, 0, 'ya'), (9, 9, 'yy')"
+ }
+
+ sql(finalStr)
+ }
+
+ Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys =>
+ withSQLConf(
+ SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+ SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+ SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key
-> "false",
+ SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key
->
+ allowJoinKeysSubsetOfPartitionKeys.toString,
+ SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") {
+ val df = sql(
+ s"""
+ |${selectWithMergeJoinHint("t1", "t2")}
+ |t1.store_id, t1.dept_id, t1.data, t2.data
+ |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2
+ |ON t1.store_id = t2.store_id AND t1.dept_id = t2.dept_id
+ |ORDER BY t1.store_id, t1.dept_id, t1.data, t2.data
+ |""".stripMargin)
+
+ val shuffles = collectShuffles(df.queryExecution.executedPlan)
+ assert(shuffles.isEmpty, "SPJ should be triggered")
+
+ val scans =
collectScans(df.queryExecution.executedPlan).map(_.inputRDD.
+ partitions.length)
+ val expectedBuckets = Math.min(table1buckets1, table2buckets1) *
+ Math.min(table1buckets2, table2buckets2)
+ assert(scans == Seq(expectedBuckets, expectedBuckets))
+
+ checkAnswer(df, Seq(
+ Row(0, 0, "aa", "aa"),
+ Row(0, 0, "aa", "ab"),
+ Row(0, 0, "ab", "aa"),
+ Row(0, 0, "ab", "ab"),
+ Row(0, 1, "ac", "ac"),
+ Row(0, 2, "ad", "ad"),
+ Row(0, 3, "ae", "ae"),
+ Row(0, 4, "af", "af"),
+ Row(0, 5, "ag", "ag"),
+ Row(1, 0, "ah", "ah"),
+ Row(1, 0, "ah", "ai"),
+ Row(1, 0, "ai", "ah"),
+ Row(1, 0, "ai", "ai"),
+ Row(1, 1, "aj", "aj"),
+ Row(1, 2, "ak", "ak"),
+ Row(1, 3, "al", "al"),
+ Row(1, 4, "am", "am"),
+ Row(1, 5, "an", "an"),
+ Row(2, 0, "ao", "ao"),
+ Row(2, 0, "ao", "ap"),
+ Row(2, 0, "ap", "ao"),
+ Row(2, 0, "ap", "ap"),
+ Row(2, 1, "aq", "aq"),
+ Row(2, 2, "ar", "ar"),
+ Row(2, 3, "as", "as"),
+ Row(2, 4, "at", "at"),
+ Row(2, 5, "au", "au"),
+ Row(3, 0, "av", "av"),
+ Row(3, 0, "av", "aw"),
+ Row(3, 0, "aw", "av"),
+ Row(3, 0, "aw", "aw"),
+ Row(3, 1, "ax", "ax"),
+ Row(3, 2, "ay", "ay"),
+ Row(3, 3, "az", "az"),
+ Row(3, 4, "ba", "ba"),
+ Row(3, 5, "bb", "bb"),
+ Row(4, 0, "bc", "bc"),
+ Row(4, 0, "bc", "bd"),
+ Row(4, 0, "bd", "bc"),
+ Row(4, 0, "bd", "bd"),
+ Row(4, 1, "be", "be"),
+ Row(4, 2, "bf", "bf"),
+ Row(4, 3, "bg", "bg"),
+ Row(4, 4, "bh", "bh"),
+ Row(4, 5, "bi", "bi"),
+ Row(5, 0, "bj", "bj"),
+ Row(5, 0, "bj", "bk"),
+ Row(5, 0, "bk", "bj"),
+ Row(5, 0, "bk", "bk"),
+ Row(5, 1, "bl", "bl"),
+ Row(5, 2, "bm", "bm"),
+ Row(5, 3, "bn", "bn"),
+ Row(5, 4, "bo", "bo"),
+ Row(5, 5, "bp", "bp")
+ ))
+ }
+ }
+ }
+ }
+
+ test("SPARK-47094: Support compatible buckets with less join keys than
partition keys") {
+ val table1 = "tab1e1"
+ val table2 = "table2"
+
+ Seq((2, 4), (4, 2), (2, 6), (6, 2)).foreach {
+ case (table1buckets, table2buckets) =>
+ catalog.clearTables()
+
+ val partition1 = Array(identity("data"),
+ bucket(table1buckets, "dept_id"))
+ val partition2 = Array(bucket(3, "store_id"),
+ bucket(table2buckets, "dept_id"))
+
+ createTable(table1, schema2, partition1)
+ sql(s"INSERT INTO testcat.ns.$table1 VALUES " +
+ "(0, 0, 'aa'), " +
+ "(1, 0, 'ab'), " +
+ "(2, 1, 'ac'), " +
+ "(3, 2, 'ad'), " +
+ "(4, 3, 'ae'), " +
+ "(5, 4, 'af'), " +
+ "(6, 5, 'ag'), " +
+
+ // value without other side match
+ "(6, 6, 'xx')"
+ )
+
+ createTable(table2, schema2, partition2)
+ sql(s"INSERT INTO testcat.ns.$table2 VALUES " +
+ "(6, 0, '01'), " +
+ "(5, 1, '02'), " + // duplicate partition key
+ "(5, 1, '03'), " +
+ "(4, 2, '04'), " +
+ "(3, 3, '05'), " +
+ "(2, 4, '06'), " +
+ "(1, 5, '07'), " +
+
+ // value without other side match
+ "(7, 7, '99')"
+ )
+
+
+ withSQLConf(
+ SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+ SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+ SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key ->
"false",
+ SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key ->
"true",
+ SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") {
+ val df = sql(
+ s"""
+ |${selectWithMergeJoinHint("t1", "t2")}
+ |t1.store_id, t2.store_id, t1.dept_id, t2.dept_id, t1.data,
t2.data
+ |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2
+ |ON t1.dept_id = t2.dept_id
+ |ORDER BY t1.store_id, t1.dept_id, t1.data, t2.data
+ |""".stripMargin)
+
+ val shuffles = collectShuffles(df.queryExecution.executedPlan)
+ assert(shuffles.isEmpty, "SPJ should be triggered")
+
+ val scans =
collectScans(df.queryExecution.executedPlan).map(_.inputRDD.
+ partitions.length)
+
+ val expectedBuckets = Math.min(table1buckets, table2buckets)
Review Comment:
Looks like the bucket numbers are always coalesced to the smaller of the two in the
current implementation?
That might not be the desired behavior when the coalesced bucket count is extremely
small, like 1, 2 or 4...
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]