cloud-fan commented on code in PR #46599:
URL: https://github.com/apache/spark/pull/46599#discussion_r1617637707
##########
sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala:
##########
@@ -1030,6 +999,135 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
}
}
+ test("hash join should be used for collated strings") {
+ val t1 = "T_1"
+ val t2 = "T_2"
+
+ case class HashJoinTestCase[R](collation: String, result: R)
+ val testCases = Seq(
+ HashJoinTestCase("UTF8_BINARY", Seq(Row("aa", 1, "aa", 2))),
+ HashJoinTestCase("UTF8_BINARY_LCASE", Seq(Row("aa", 1, "AA", 2), Row("aa", 1, "aa", 2))),
+ HashJoinTestCase("UNICODE", Seq(Row("aa", 1, "aa", 2))),
+ HashJoinTestCase("UNICODE_CI", Seq(Row("aa", 1, "AA", 2), Row("aa", 1, "aa", 2)))
+ )
+
+ testCases.foreach(t => {
+ withTable(t1, t2) {
+ sql(s"CREATE TABLE $t1 (x STRING COLLATE ${t.collation}, i int) USING PARQUET")
+ sql(s"INSERT INTO $t1 VALUES ('aa', 1)")
+
+ sql(s"CREATE TABLE $t2 (y STRING COLLATE ${t.collation}, j int) USING PARQUET")
+ sql(s"INSERT INTO $t2 VALUES ('AA', 2), ('aa', 2)")
+
+ val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y")
+ checkAnswer(df, t.result)
+
+ val queryPlan = df.queryExecution.executedPlan
+
+ // confirm that hash join is used instead of sort merge join
+ assert(
+ collectFirst(queryPlan) {
+ case _: BroadcastHashJoinExec => ()
+ }.nonEmpty
+ )
+ assert(
+ collectFirst(queryPlan) {
+ case _: SortMergeJoinExec => ()
+ }.isEmpty
+ )
+
+ // if collation doesn't support binary equality, collation key should be injected
+ if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) {
+ assert(collectFirst(queryPlan) {
+ case b: BroadcastHashJoinExec => b.leftKeys.head
+ }.head.isInstanceOf[CollationKey])
+ }
+ }
+ })
+ }
+
+ test("rewrite with collationkey should be an excludable rule") {
+ val t1 = "T_1"
+ val t2 = "T_2"
+ val collation = "UTF8_BINARY_LCASE"
+ val collationRewriteJoinRule = "org.apache.spark.sql.catalyst.analysis.RewriteCollationJoin"
+ withTable(t1, t2) {
+ withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> collationRewriteJoinRule) {
+ sql(s"CREATE TABLE $t1 (x STRING COLLATE $collation, i int) USING PARQUET")
+ sql(s"INSERT INTO $t1 VALUES ('aa', 1)")
+
+ sql(s"CREATE TABLE $t2 (y STRING COLLATE $collation, j int) USING PARQUET")
+ sql(s"INSERT INTO $t2 VALUES ('AA', 2), ('aa', 2)")
+
+ val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y")
+ checkAnswer(df, Seq(Row("aa", 1, "AA", 2), Row("aa", 1, "aa", 2)))
+
+ val queryPlan = df.queryExecution.executedPlan
+
+ // confirm that shuffle join is used instead of hash join
+ assert(
+ collectFirst(queryPlan) {
+ case _: BroadcastHashJoinExec => ()
+ }.isEmpty
+ )
+ assert(
+ collectFirst(queryPlan) {
+ case _: SortMergeJoinExec => ()
+ }.nonEmpty
+ )
+ }
+ }
+ }
+
+ test("rewrite with collationkey shouldn't disrupt multiple join conditions") {
+ val t1 = "T_1"
+ val t2 = "T_2"
+
+ case class HashMultiJoinTestCase[R](
+ type1: String,
+ type2: String,
+ data1: String,
+ data2: String,
+ result: R
+ )
+ val testCases = Seq(
+ HashMultiJoinTestCase("STRING COLLATE UTF8_BINARY", "INT",
+ "'a', 0, 1", "'a', 0, 1", Row("a", 0, 1, "a", 0, 1)),
+ HashMultiJoinTestCase("STRING COLLATE UTF8_BINARY", "STRING COLLATE UTF8_BINARY",
+ "'a', 'a', 1", "'a', 'a', 1", Row("a", "a", 1, "a", "a", 1)),
+ HashMultiJoinTestCase("STRING COLLATE UTF8_BINARY", "STRING COLLATE UTF8_BINARY_LCASE",
+ "'a', 'a', 1", "'a', 'A', 1", Row("a", "a", 1, "a", "A", 1)),
+ HashMultiJoinTestCase("STRING COLLATE UTF8_BINARY_LCASE", "STRING COLLATE UNICODE_CI",
+ "'a', 'a', 1", "'A', 'A', 1", Row("a", "a", 1, "A", "A", 1))
+ )
+
+ testCases.foreach(t => {
+ withTable(t1, t2) {
+ sql(s"CREATE TABLE $t1 (x ${t.type1}, y ${t.type2}, i int) USING PARQUET")
+ sql(s"INSERT INTO $t1 VALUES (${t.data1})")
+ sql(s"CREATE TABLE $t2 (x ${t.type1}, y ${t.type2}, i int) USING PARQUET")
+ sql(s"INSERT INTO $t2 VALUES (${t.data2})")
+
+ val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.x AND $t1.y = $t2.y")
+ checkAnswer(df, t.result)
+
+ val queryPlan = df.queryExecution.executedPlan
+
+ // confirm that hash join is used instead of sort merge join
+ assert(
+ collectFirst(queryPlan) {
+ case _: BroadcastHashJoinExec => ()
Review Comment:
ditto
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]