c21 commented on a change in pull request #31708:
URL: https://github.com/apache/spark/pull/31708#discussion_r585418496



##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
##########
@@ -238,7 +250,6 @@ case class BroadcastNestedLoopJoinExec(
    *   ExistenceJoin with BuildLeft
    */
   private def defaultJoin(relation: Broadcast[Array[InternalRow]]): 
RDD[InternalRow] = {
-    /** All rows that either match both-way, or rows from streamed joined with 
nulls. */

Review comment:
       This comment is quite confusing to me, so I deleted it here.
   I am not sure how it can be correct, given that e.g. a `FullOuter` join in this 
method will produce rows that match both ways, as well as rows from the streamed AND 
build sides joined with nulls, not only rows from the streamed side.
   

##########
File path: sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
##########
@@ -1296,4 +1296,92 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
       }
     }
   }
+
+  test("SPARK-34593: Preserve broadcast nested loop join partitioning and 
ordering") {
+    withTable("t1", "t2", "t3", "t4", "t5") {
+      spark.range(15).toDF("k").write.bucketBy(4, "k").saveAsTable("t1")
+      spark.range(6).toDF("k").write.bucketBy(4, "k").saveAsTable("t2")
+      spark.range(8).toDF("k").write.saveAsTable("t3")
+      spark.range(9).toDF("k").write.saveAsTable("t4")
+      spark.range(11).toDF("k").write.saveAsTable("t5")
+
+      def getAggQuery(selectExpr: String, joinType: String): String = {
+        s"""
+           |SELECT k, COUNT(*)
+           |FROM (SELECT $selectExpr FROM t1 $joinType JOIN t2)
+           |GROUP BY k
+         """.stripMargin
+      }
+
+      // Test output partitioning is preserved
+      Seq("INNER", "LEFT OUTER", "RIGHT OUTER", "LEFT SEMI", "LEFT 
ANTI").foreach {
+        joinType =>
+          val selectExpr = if (joinType == "RIGHT OUTER") {
+            "/*+ BROADCAST(t1) */ t2.k AS k"
+          } else {
+            "/*+ BROADCAST(t2) */ t1.k as k"
+          }
+          val plan = sql(getAggQuery(selectExpr, 
joinType)).queryExecution.executedPlan
+          assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true 
}.size === 1)
+          // No extra shuffle before aggregation
+          assert(collect(plan) { case _: ShuffleExchangeExec => true }.size 
=== 0)
+      }
+
+      // Test output partitioning is not preserved
+      Seq("LEFT OUTER", "RIGHT OUTER", "LEFT SEMI", "LEFT ANTI", "FULL 
OUTER").foreach {
+        joinType =>
+          val selectExpr = if (joinType == "RIGHT OUTER") {
+            "/*+ BROADCAST(t2) */ t1.k AS k"
+          } else {
+            "/*+ BROADCAST(t1) */ t1.k as k"
+          }
+          val plan = sql(getAggQuery(selectExpr, 
joinType)).queryExecution.executedPlan
+          assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true 
}.size === 1)
+          // Have shuffle before aggregation
+          assert(collect(plan) { case _: ShuffleExchangeExec => true }.size 
=== 1)
+      }
+
+      def getJoinQuery(selectExpr: String, joinType: String): String = {
+        s"""
+           |SELECT /*+ MERGE(t3) */ t3.k
+           |FROM
+           |(
+           |  SELECT $selectExpr
+           |  FROM
+           |    (SELECT /*+ MERGE(t4) */ t1.k AS k1 FROM t1 JOIN t4 ON t1.k = 
t4.k) AS left_t
+           |  $joinType JOIN
+           |    (SELECT /*+ MERGE(t5) */ t2.k AS k2 FROM t2 JOIN t5 ON t2.k = 
t5.k) AS right_t
+           |)
+           |JOIN t3
+           |ON t3.k = k0
+         """.stripMargin
+      }
+
+      // Test output ordering is preserved
+      Seq("INNER", "LEFT OUTER", "RIGHT OUTER", "LEFT SEMI", "LEFT 
ANTI").foreach {
+        joinType =>
+          val selectExpr = if (joinType == "RIGHT OUTER") {
+            "/*+ BROADCAST(left_t) */ k2 AS k0"
+          } else {
+            "/*+ BROADCAST(right_t) */ k1 as k0"
+          }
+          val plan = sql(getJoinQuery(selectExpr, 
joinType)).queryExecution.executedPlan
+          assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true 
}.size === 1)
+          assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 
3)
+          // No extra sort on left side before last sort merge join
+          assert(collect(plan) { case _: SortExec => true }.size === 5)
+      }
+
+      // Test output ordering is not preserved
+      Seq("LEFT OUTER", "FULL OUTER").foreach {

Review comment:
       Tests for `RIGHT OUTER`, `LEFT SEMI`, and `LEFT ANTI` are omitted here, as I 
cannot reproduce a valid test case. The join type of the broadcast nested loop join 
always gets changed to `INNER`, or the join gets reordered somehow. I think some 
optimization rules are taking effect.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to