cloud-fan commented on a change in pull request #28676:
URL: https://github.com/apache/spark/pull/28676#discussion_r455712732



##########
File path: 
sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
##########
@@ -415,6 +417,216 @@ abstract class BroadcastJoinSuiteBase extends QueryTest 
with SQLTestUtils
       assert(e.getMessage.contains(s"Could not execute broadcast in $timeout 
secs."))
     }
   }
+
+  test("broadcast join where streamed side's output partitioning is 
HashPartitioning") {
+    withTable("t1", "t3") {
+      withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") {
+        val df1 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i1", "j1")
+        val df2 = (0 until 20).map(i => (i % 7, i % 11)).toDF("i2", "j2")
+        val df3 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i3", "j3")
+        df1.write.format("parquet").bucketBy(8, "i1", "j1").saveAsTable("t1")
+        df3.write.format("parquet").bucketBy(8, "i3", "j3").saveAsTable("t3")
+        val t1 = spark.table("t1")
+        val t3 = spark.table("t3")
+
+        // join1 is a broadcast join where df2 is broadcasted. Note that 
output partitioning on the
+        // streamed side (t1) is HashPartitioning (bucketed files).
+        val join1 = t1.join(df2, t1("i1") === df2("i2") && t1("j1") === 
df2("j2"))
+        val plan1 = join1.queryExecution.executedPlan
+        assert(collect(plan1) { case e: ShuffleExchangeExec => e }.isEmpty)
+        val broadcastJoins = collect(plan1) { case b: BroadcastHashJoinExec => 
b }
+        assert(broadcastJoins.size == 1)
+        
assert(broadcastJoins(0).outputPartitioning.isInstanceOf[PartitioningCollection])
+        val p = 
broadcastJoins(0).outputPartitioning.asInstanceOf[PartitioningCollection]
+        assert(p.partitionings.size == 4)
+        // Verify all the combinations of output partitioning.
+        Seq(Seq(t1("i1"), t1("j1")),
+          Seq(t1("i1"), df2("j2")),
+          Seq(df2("i2"), t1("j1")),
+          Seq(df2("i2"), df2("j2"))).foreach { expected =>
+          val expectedExpressions = expected.map(_.expr)
+          assert(p.partitionings.exists {
+            case h: HashPartitioning => expressionsEqual(h.expressions, 
expectedExpressions)
+          })
+        }
+
+        // Join on the column from the broadcasted side (i2, j2) and make sure 
output partitioning
+        // is maintained by checking no shuffle exchange is introduced.
+        val join2 = join1.join(t3, join1("i2") === t3("i3") && join1("j2") === 
t3("j3"))
+        val plan2 = join2.queryExecution.executedPlan
+        assert(collect(plan2) { case s: SortMergeJoinExec => s }.size == 1)
+        assert(collect(plan2) { case b: BroadcastHashJoinExec => b }.size == 1)
+        assert(collect(plan2) { case e: ShuffleExchangeExec => e }.isEmpty)
+
+        // Validate the data with broadcast join off.
+        withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+          val df = join1.join(t3, join1("i2") === t3("i3") && join1("j2") === 
t3("j3"))
+          QueryTest.sameRows(join2.collect().toSeq, df.collect().toSeq)

Review comment:
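       `checkAnswer` supports comparing two DataFrames, so this line can be written as `checkAnswer(join2, df)` instead of calling `QueryTest.sameRows` on the collected rows.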




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to