imback82 commented on a change in pull request #28676:
URL: https://github.com/apache/spark/pull/28676#discussion_r451259973
##########
File path:
sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
##########
@@ -415,6 +417,192 @@ abstract class BroadcastJoinSuiteBase extends QueryTest
with SQLTestUtils
assert(e.getMessage.contains(s"Could not execute broadcast in $timeout
secs."))
}
}
+
+ test("broadcast join where streamed side's output partitioning is
HashPartitioning") {
+ withTable("t1", "t3") {
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") {
+ val df1 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i1", "j1")
+ val df2 = (0 until 20).map(i => (i % 7, i % 11)).toDF("i2", "j2")
+ val df3 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i3", "j3")
+ df1.write.format("parquet").bucketBy(8, "i1", "j1").saveAsTable("t1")
+ df3.write.format("parquet").bucketBy(8, "i3", "j3").saveAsTable("t3")
+ val t1 = spark.table("t1")
+ val t3 = spark.table("t3")
+
+ // join1 is a broadcast join where df2 is broadcasted. Note that
output partitioning on the
+ // streamed side (t1) is HashPartitioning (bucketed files).
+ val join1 = t1.join(df2, t1("i1") === df2("i2") && t1("j1") ===
df2("j2"))
+ val plan1 = join1.queryExecution.executedPlan
+ assert(collect(plan1) { case e: ShuffleExchangeExec => e }.isEmpty)
+ val broadcastJoins = collect(plan1) { case b: BroadcastHashJoinExec =>
b }
+ assert(broadcastJoins.size == 1)
+ broadcastJoins(0).outputPartitioning match {
+ case p: PartitioningCollection =>
+ assert(p.partitionings.size == 4)
+ // Verify all the combinations of output partitioning.
+ Seq(Seq(t1("i1"), t1("j1")),
+ Seq(t1("i1"), df2("j2")),
+ Seq(df2("i2"), t1("j1")),
+ Seq(df2("i2"), df2("j2"))).foreach { expected =>
+ val expectedExpressions = expected.map(_.expr)
+ assert(p.partitionings.exists {
+ case h: HashPartitioning => expressionsEqual(h.expressions,
expectedExpressions)
+ })
+ }
+ case _ => fail()
+ }
+
+      // Join on the columns from the broadcasted side (i2, j2) and make sure
output partitioning
+      // is maintained by checking no shuffle exchange is introduced.
+ val join2 = join1.join(t3, join1("i2") === t3("i3") && join1("j2") ===
t3("j3"))
+ val plan2 = join2.queryExecution.executedPlan
+ assert(collect(plan2) { case s: SortMergeJoinExec => s }.size == 1)
+ assert(collect(plan2) { case b: BroadcastHashJoinExec => b }.size == 1)
+ assert(collect(plan2) { case e: ShuffleExchangeExec => e }.isEmpty)
+
+ // Validate the data with broadcast join off.
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+ val df = join1.join(t3, join1("i2") === t3("i3") && join1("j2") ===
t3("j3"))
+ QueryTest.sameRows(join2.collect().toSeq, df.collect().toSeq)
+ }
+ }
+ }
+ }
+
+ test("broadcast join where streamed side's output partitioning is
PartitioningCollection") {
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") {
+ val t1 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i1", "j1")
+ val t2 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i2", "j2")
+ val t3 = (0 until 20).map(i => (i % 7, i % 11)).toDF("i3", "j3")
+ val t4 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i4", "j4")
+
+ // join1 is a sort merge join (shuffle on the both sides).
+ val join1 = t1.join(t2, t1("i1") === t2("i2"))
+ val plan1 = join1.queryExecution.executedPlan
+ assert(collect(plan1) { case s: SortMergeJoinExec => s }.size == 1)
+ assert(collect(plan1) { case e: ShuffleExchangeExec => e }.size == 2)
+
+      // join2 is a broadcast join where t3 is broadcasted. Note that output
partitioning on the
+      // streamed side (join1) is PartitioningCollection (sort merge join).
+ val join2 = join1.join(t3, join1("i1") === t3("i3"))
+ val plan2 = join2.queryExecution.executedPlan
+ assert(collect(plan2) { case s: SortMergeJoinExec => s }.size == 1)
+ assert(collect(plan2) { case e: ShuffleExchangeExec => e }.size == 2)
+ val broadcastJoins = collect(plan2) { case b: BroadcastHashJoinExec => b
}
+ assert(broadcastJoins.size == 1)
+ broadcastJoins(0).outputPartitioning match {
+ case p: PartitioningCollection =>
+ assert(p.partitionings.size == 3)
+ // Verify all the combinations of output partitioning.
+ Seq(Seq(t1("i1")), Seq(t2("i2")), Seq(t3("i3"))).foreach { expected
=>
+ val expectedExpressions = expected.map(_.expr)
+ assert(p.partitionings.exists {
+ case h: HashPartitioning => expressionsEqual(h.expressions,
expectedExpressions)
+ })
+ }
+ case _ => fail()
Review comment:
Changed as suggested.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]