MaxGekk commented on code in PR #48774:
URL: https://github.com/apache/spark/pull/48774#discussion_r1834065574
##########
sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala:
##########
@@ -1562,47 +1596,52 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
       HashJoinTestCase("UNICODE_CI_RTRIM", "aa", "AA ", Seq(Row("aa", 1, "AA ", 2),
         Row("aa", 1, "aa", 2)))
     )
-
-    testCases.foreach(t => {
+    for {
+      t <- testCases
+      broadcastJoinThreshold <- Seq(-1, SQLConf.get.getConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD))
+    } {
       withTable(t1, t2) {
-        sql(s"CREATE TABLE $t1 (x STRING COLLATE ${t.collation}, i int) USING PARQUET")
-        sql(s"INSERT INTO $t1 VALUES ('${t.data1}', 1)")
+        withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> broadcastJoinThreshold.toString) {
+          sql(s"CREATE TABLE $t1 (x STRING COLLATE ${t.collation}, i int) USING PARQUET")
+          sql(s"INSERT INTO $t1 VALUES ('${t.data1}', 1)")
 
-        sql(s"CREATE TABLE $t2 (y STRING COLLATE ${t.collation}, j int) USING PARQUET")
-        sql(s"INSERT INTO $t2 VALUES ('${t.data2}', 2), ('${t.data1}', 2)")
+          sql(s"CREATE TABLE $t2 (y STRING COLLATE ${t.collation}, j int) USING PARQUET")
+          sql(s"INSERT INTO $t2 VALUES ('${t.data2}', 2), ('${t.data1}', 2)")
 
-        val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y")
-        checkAnswer(df, t.result)
+          val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y")
+          checkAnswer(df, t.result)
 
-        val queryPlan = df.queryExecution.executedPlan
+          val queryPlan = df.queryExecution.executedPlan
 
-        // confirm that hash join is used instead of sort merge join
-        assert(
-          collectFirst(queryPlan) {
-            case _: HashJoin => ()
-          }.nonEmpty
-        )
-        assert(
-          collectFirst(queryPlan) {
-            case _: SortMergeJoinExec => ()
-          }.isEmpty
-        )
+          // confirm that right kind of join is used.
+          checkRightTypeOfJoinUsed(queryPlan)
 
-        // Only if collation doesn't support binary equality, collation key should be injected.
-        if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) {
-          assert(collectFirst(queryPlan) {
-            case b: HashJoin => b.leftKeys.head
-          }.head.isInstanceOf[CollationKey])
-        } else {
-          assert(!collectFirst(queryPlan) {
-            case b: HashJoin => b.leftKeys.head
-          }.head.isInstanceOf[CollationKey])
+          if (isSortMergeForced) {
+            // Only if collation doesn't support binary equality, collation key should be injected.
+            if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) {
+              assert(queryPlan.toString().contains("collationkey"))
+            } else {
+              assert(!queryPlan.toString().contains("collationkey"))
+            }
+          }
+          else {
+            // Only if collation doesn't support binary equality, collation key should be injected.
+            if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) {
+              assert(collectFirst(queryPlan) {
Review Comment:
Can you use `find`?
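
For illustration only, a minimal sketch of what a `find`-based check could look like (editor's hedged rewrite of the snippet above, not the final patch; `find` is assumed to be `TreeNode.find`, and `queryPlan`, `HashJoin`, `CollationKey` come from the diff context):
```scala
// Locate the first HashJoin in the executed plan via TreeNode.find, which
// returns Option[SparkPlan], and inspect its left join key directly instead
// of going through collectFirst(...).head.
queryPlan.find(_.isInstanceOf[HashJoin]) match {
  case Some(hj: HashJoin) =>
    assert(hj.leftKeys.head.isInstanceOf[CollationKey])
  case None =>
    fail("expected a HashJoin in the executed plan")
}
```
One caveat: unlike `AdaptiveSparkPlanHelper.collectFirst`, plain `TreeNode.find` does not descend into adaptive subplans, which may or may not matter for these tests.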
##########
sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala:
##########
@@ -1562,47 +1596,52 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
       HashJoinTestCase("UNICODE_CI_RTRIM", "aa", "AA ", Seq(Row("aa", 1, "AA ", 2),
         Row("aa", 1, "aa", 2)))
     )
-
-    testCases.foreach(t => {
+    for {
+      t <- testCases
+      broadcastJoinThreshold <- Seq(-1, SQLConf.get.getConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD))
+    } {
       withTable(t1, t2) {
-        sql(s"CREATE TABLE $t1 (x STRING COLLATE ${t.collation}, i int) USING PARQUET")
-        sql(s"INSERT INTO $t1 VALUES ('${t.data1}', 1)")
+        withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> broadcastJoinThreshold.toString) {
+          sql(s"CREATE TABLE $t1 (x STRING COLLATE ${t.collation}, i int) USING PARQUET")
+          sql(s"INSERT INTO $t1 VALUES ('${t.data1}', 1)")
 
-        sql(s"CREATE TABLE $t2 (y STRING COLLATE ${t.collation}, j int) USING PARQUET")
-        sql(s"INSERT INTO $t2 VALUES ('${t.data2}', 2), ('${t.data1}', 2)")
+          sql(s"CREATE TABLE $t2 (y STRING COLLATE ${t.collation}, j int) USING PARQUET")
+          sql(s"INSERT INTO $t2 VALUES ('${t.data2}', 2), ('${t.data1}', 2)")
 
-        val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y")
-        checkAnswer(df, t.result)
+          val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y")
+          checkAnswer(df, t.result)
 
-        val queryPlan = df.queryExecution.executedPlan
+          val queryPlan = df.queryExecution.executedPlan
 
-        // confirm that hash join is used instead of sort merge join
-        assert(
-          collectFirst(queryPlan) {
-            case _: HashJoin => ()
-          }.nonEmpty
-        )
-        assert(
-          collectFirst(queryPlan) {
-            case _: SortMergeJoinExec => ()
-          }.isEmpty
-        )
+          // confirm that right kind of join is used.
+          checkRightTypeOfJoinUsed(queryPlan)
 
-        // Only if collation doesn't support binary equality, collation key should be injected.
-        if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) {
-          assert(collectFirst(queryPlan) {
-            case b: HashJoin => b.leftKeys.head
-          }.head.isInstanceOf[CollationKey])
-        } else {
-          assert(!collectFirst(queryPlan) {
-            case b: HashJoin => b.leftKeys.head
-          }.head.isInstanceOf[CollationKey])
+          if (isSortMergeForced) {
+            // Only if collation doesn't support binary equality, collation key should be injected.
+            if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) {
Review Comment:
BTW, `RewriteCollationJoin` checks `supportsBinaryEquality`, which is a stronger condition:
```
supportsBinaryEquality = !supportsSpaceTrimming && isUtf8BinaryType
```
Could you align the test with `RewriteCollationJoin`?
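
For illustration, a hedged sketch of what the aligned check could look like (assuming the fetched collation exposes `supportsBinaryEquality` as the formula above suggests; `t` and `queryPlan` come from the diff context):
```scala
// Mirror RewriteCollationJoin's condition: a collation key is expected only
// when the collation does NOT support binary equality, i.e. it is either
// non-binary or performs space trimming.
val collation = CollationFactory.fetchCollation(t.collation)
if (!collation.supportsBinaryEquality) {
  assert(queryPlan.toString().contains("collationkey"))
} else {
  assert(!queryPlan.toString().contains("collationkey"))
}
```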
##########
sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala:
##########
@@ -43,6 +44,39 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
   private val collationNonPreservingSources = Seq("orc", "csv", "json", "text")
   private val allFileBasedDataSources = collationPreservingSources ++ collationNonPreservingSources
+  @inline
+  private def isSortMergeForced: Boolean = {
+    SQLConf.get.getConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD) == -1
+  }
+
+  private def checkRightTypeOfJoinUsed(queryPlan: SparkPlan): Unit = {
+    assert(
+      // If sort merge join is forced, we should not see HashJoin in the plan.
+      isSortMergeForced ||
+        // If sort merge join is not forced, we should see HashJoin in the plan
+        // and not SortMergeJoin.
+        collectFirst(queryPlan) {
+          case _: HashJoin => ()
+        }.nonEmpty &&
+        collectFirst(queryPlan) {
+          case _: SortMergeJoinExec => ()
+        }.isEmpty
+    )
+
+    assert(
+      // If sort merge join is not forced, we should not see SortMergeJoin in the plan.
+      !isSortMergeForced ||
+        // If sort merge join is forced, we should see SortMergeJoin in the plan
+        // and not HashJoin.
+        collectFirst(queryPlan) {
+          case _: HashJoin => ()
+        }.isEmpty &&
+        collectFirst(queryPlan) {
+          case _: SortMergeJoinExec => ()
+        }.nonEmpty
+    )
+  }
Review Comment:
Could you simplify the checks:
```suggestion
private def checkRightTypeOfJoinUsed(queryPlan: SparkPlan): Unit = {
foreach(queryPlan) {
case _: HashJoin => assert(!isSortMergeForced)
case _: SortMergeJoinExec => assert(isSortMergeForced)
case _ => ()
}
}
```
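
A note on the suggestion: `foreach` here is presumably the one from `AdaptiveSparkPlanHelper`, which the suite already mixes in, so the whole executed plan (including adaptive subplans) is walked in a single traversal and each join node gets a direct assertion, replacing the two `collectFirst` scans and the `isSortMergeForced` branching with one pass.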
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]