szehon-ho commented on code in PR #45267:
URL: https://github.com/apache/spark/pull/45267#discussion_r1506338476


##########
sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala:
##########
@@ -1310,6 +1314,312 @@ class KeyGroupedPartitioningSuite extends 
DistributionAndOrderingSuiteBase {
     }
   }
 
+  test("SPARK-47094: Support compatible buckets") {
+    val table1 = "tab1e1"
+    val table2 = "table2"
+
+    Seq(
+      ((2, 4), (4, 2)),
+      ((4, 2), (2, 4)),
+      ((2, 2), (4, 6)),
+      ((6, 2), (2, 2))).foreach {
+      case ((table1buckets1, table1buckets2), (table2buckets1, 
table2buckets2)) =>
+        catalog.clearTables()
+
+        val partition1 = Array(bucket(table1buckets1, "store_id"),
+          bucket(table1buckets2, "dept_id"))
+        val partition2 = Array(bucket(table2buckets1, "store_id"),
+          bucket(table2buckets2, "dept_id"))
+
+        Seq((table1, partition1), (table2, partition2)).foreach { case (tab, 
part) =>
+          createTable(tab, schema2, part)
+          val insertStr = s"INSERT INTO testcat.ns.$tab VALUES " +
+            "(0, 0, 'aa'), " +
+            "(0, 0, 'ab'), " + // duplicate partition key
+            "(0, 1, 'ac'), " +
+            "(0, 2, 'ad'), " +
+            "(0, 3, 'ae'), " +
+            "(0, 4, 'af'), " +
+            "(0, 5, 'ag'), " +
+            "(1, 0, 'ah'), " +
+            "(1, 0, 'ai'), " + // duplicate partition key
+            "(1, 1, 'aj'), " +
+            "(1, 2, 'ak'), " +
+            "(1, 3, 'al'), " +
+            "(1, 4, 'am'), " +
+            "(1, 5, 'an'), " +
+            "(2, 0, 'ao'), " +
+            "(2, 0, 'ap'), " + // duplicate partition key
+            "(2, 1, 'aq'), " +
+            "(2, 2, 'ar'), " +
+            "(2, 3, 'as'), " +
+            "(2, 4, 'at'), " +
+            "(2, 5, 'au'), " +
+            "(3, 0, 'av'), " +
+            "(3, 0, 'aw'), " + // duplicate partition key
+            "(3, 1, 'ax'), " +
+            "(3, 2, 'ay'), " +
+            "(3, 3, 'az'), " +
+            "(3, 4, 'ba'), " +
+            "(3, 5, 'bb'), " +
+            "(4, 0, 'bc'), " +
+            "(4, 0, 'bd'), " + // duplicate partition key
+            "(4, 1, 'be'), " +
+            "(4, 2, 'bf'), " +
+            "(4, 3, 'bg'), " +
+            "(4, 4, 'bh'), " +
+            "(4, 5, 'bi'), " +
+            "(5, 0, 'bj'), " +
+            "(5, 0, 'bk'), " + // duplicate partition key
+            "(5, 1, 'bl'), " +
+            "(5, 2, 'bm'), " +
+            "(5, 3, 'bn'), " +
+            "(5, 4, 'bo'), " +
+            "(5, 5, 'bp')"
+
+            // additional unmatched partitions to test push down
+            val finalStr = if (tab == table1) {
+              insertStr ++ ", (8, 0, 'xa'), (8, 8, 'xx')"
+            } else {
+              insertStr ++ ", (9, 0, 'ya'), (9, 9, 'yy')"
+            }
+
+            sql(finalStr)
+        }
+
+        Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys =>
+          withSQLConf(
+            SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+            SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+            SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key 
-> "false",
+            SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key 
->
+              allowJoinKeysSubsetOfPartitionKeys.toString,
+            SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") {
+            val df = sql(
+              s"""
+                 |${selectWithMergeJoinHint("t1", "t2")}
+                 |t1.store_id, t1.dept_id, t1.data, t2.data
+                 |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2
+                 |ON t1.store_id = t2.store_id AND t1.dept_id = t2.dept_id
+                 |ORDER BY t1.store_id, t1.dept_id, t1.data, t2.data
+                 |""".stripMargin)
+
+            val shuffles = collectShuffles(df.queryExecution.executedPlan)
+            assert(shuffles.isEmpty, "SPJ should be triggered")
+
+            val scans = 
collectScans(df.queryExecution.executedPlan).map(_.inputRDD.
+              partitions.length)
+            val expectedBuckets = Math.min(table1buckets1, table2buckets1) *
+              Math.min(table1buckets2, table2buckets2)
+            assert(scans == Seq(expectedBuckets, expectedBuckets))
+
+            checkAnswer(df, Seq(
+              Row(0, 0, "aa", "aa"),
+              Row(0, 0, "aa", "ab"),
+              Row(0, 0, "ab", "aa"),
+              Row(0, 0, "ab", "ab"),
+              Row(0, 1, "ac", "ac"),
+              Row(0, 2, "ad", "ad"),
+              Row(0, 3, "ae", "ae"),
+              Row(0, 4, "af", "af"),
+              Row(0, 5, "ag", "ag"),
+              Row(1, 0, "ah", "ah"),
+              Row(1, 0, "ah", "ai"),
+              Row(1, 0, "ai", "ah"),
+              Row(1, 0, "ai", "ai"),
+              Row(1, 1, "aj", "aj"),
+              Row(1, 2, "ak", "ak"),
+              Row(1, 3, "al", "al"),
+              Row(1, 4, "am", "am"),
+              Row(1, 5, "an", "an"),
+              Row(2, 0, "ao", "ao"),
+              Row(2, 0, "ao", "ap"),
+              Row(2, 0, "ap", "ao"),
+              Row(2, 0, "ap", "ap"),
+              Row(2, 1, "aq", "aq"),
+              Row(2, 2, "ar", "ar"),
+              Row(2, 3, "as", "as"),
+              Row(2, 4, "at", "at"),
+              Row(2, 5, "au", "au"),
+              Row(3, 0, "av", "av"),
+              Row(3, 0, "av", "aw"),
+              Row(3, 0, "aw", "av"),
+              Row(3, 0, "aw", "aw"),
+              Row(3, 1, "ax", "ax"),
+              Row(3, 2, "ay", "ay"),
+              Row(3, 3, "az", "az"),
+              Row(3, 4, "ba", "ba"),
+              Row(3, 5, "bb", "bb"),
+              Row(4, 0, "bc", "bc"),
+              Row(4, 0, "bc", "bd"),
+              Row(4, 0, "bd", "bc"),
+              Row(4, 0, "bd", "bd"),
+              Row(4, 1, "be", "be"),
+              Row(4, 2, "bf", "bf"),
+              Row(4, 3, "bg", "bg"),
+              Row(4, 4, "bh", "bh"),
+              Row(4, 5, "bi", "bi"),
+              Row(5, 0, "bj", "bj"),
+              Row(5, 0, "bj", "bk"),
+              Row(5, 0, "bk", "bj"),
+              Row(5, 0, "bk", "bk"),
+              Row(5, 1, "bl", "bl"),
+              Row(5, 2, "bm", "bm"),
+              Row(5, 3, "bn", "bn"),
+              Row(5, 4, "bo", "bo"),
+              Row(5, 5, "bp", "bp")
+            ))
+          }
+        }
+    }
+  }
+
+  test("SPARK-47094: Support compatible buckets with less join keys than 
partition keys") {
+    val table1 = "tab1e1"
+    val table2 = "table2"
+
+    Seq((2, 4), (4, 2), (2, 6), (6, 2)).foreach {
+      case (table1buckets, table2buckets) =>
+        catalog.clearTables()
+
+        val partition1 = Array(identity("data"),
+          bucket(table1buckets, "dept_id"))
+        val partition2 = Array(bucket(3, "store_id"),
+          bucket(table2buckets, "dept_id"))
+
+        createTable(table1, schema2, partition1)
+        sql(s"INSERT INTO testcat.ns.$table1 VALUES " +
+          "(0, 0, 'aa'), " +
+          "(1, 0, 'ab'), " +
+          "(2, 1, 'ac'), " +
+          "(3, 2, 'ad'), " +
+          "(4, 3, 'ae'), " +
+          "(5, 4, 'af'), " +
+          "(6, 5, 'ag'), " +
+
+          // value without other side match
+          "(6, 6, 'xx')"
+        )
+
+        createTable(table2, schema2, partition2)
+        sql(s"INSERT INTO testcat.ns.$table2 VALUES " +
+          "(6, 0, '01'), " +
+          "(5, 1, '02'), " + // duplicate partition key
+          "(5, 1, '03'), " +
+          "(4, 2, '04'), " +
+          "(3, 3, '05'), " +
+          "(2, 4, '06'), " +
+          "(1, 5, '07'), " +
+
+          // value without other side match
+          "(7, 7, '99')"
+        )
+
+
+        withSQLConf(
+          SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+          SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+          SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> 
"false",
+          SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> 
"true",
+          SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") {
+          val df = sql(
+            s"""
+               |${selectWithMergeJoinHint("t1", "t2")}
+               |t1.store_id, t2.store_id, t1.dept_id, t2.dept_id, t1.data, 
t2.data
+               |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2
+               |ON t1.dept_id = t2.dept_id
+               |ORDER BY t1.store_id, t1.dept_id, t1.data, t2.data
+               |""".stripMargin)
+
+          val shuffles = collectShuffles(df.queryExecution.executedPlan)
+          assert(shuffles.isEmpty, "SPJ should be triggered")
+
+          val scans = 
collectScans(df.queryExecution.executedPlan).map(_.inputRDD.
+            partitions.length)
+
+          val expectedBuckets = Math.min(table1buckets, table2buckets)

Review Comment:
   Yes, in fact it is definitely a follow-up to do this with 
'partiallyClustered' (which is currently not enabled).  This mode today 
'duplicates' partitions for the partially clustered side, and I think it can be 
used to turn on the same for compatible transforms.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to