cloud-fan commented on a change in pull request #28676: URL: https://github.com/apache/spark/pull/28676#discussion_r455712732
########## File path: sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala ########## @@ -415,6 +417,216 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils assert(e.getMessage.contains(s"Could not execute broadcast in $timeout secs.")) } } + + test("broadcast join where streamed side's output partitioning is HashPartitioning") { + withTable("t1", "t3") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") { + val df1 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i1", "j1") + val df2 = (0 until 20).map(i => (i % 7, i % 11)).toDF("i2", "j2") + val df3 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i3", "j3") + df1.write.format("parquet").bucketBy(8, "i1", "j1").saveAsTable("t1") + df3.write.format("parquet").bucketBy(8, "i3", "j3").saveAsTable("t3") + val t1 = spark.table("t1") + val t3 = spark.table("t3") + + // join1 is a broadcast join where df2 is broadcasted. Note that output partitioning on the + // streamed side (t1) is HashPartitioning (bucketed files). + val join1 = t1.join(df2, t1("i1") === df2("i2") && t1("j1") === df2("j2")) + val plan1 = join1.queryExecution.executedPlan + assert(collect(plan1) { case e: ShuffleExchangeExec => e }.isEmpty) + val broadcastJoins = collect(plan1) { case b: BroadcastHashJoinExec => b } + assert(broadcastJoins.size == 1) + assert(broadcastJoins(0).outputPartitioning.isInstanceOf[PartitioningCollection]) + val p = broadcastJoins(0).outputPartitioning.asInstanceOf[PartitioningCollection] + assert(p.partitionings.size == 4) + // Verify all the combinations of output partitioning. 
+ Seq(Seq(t1("i1"), t1("j1")), + Seq(t1("i1"), df2("j2")), + Seq(df2("i2"), t1("j1")), + Seq(df2("i2"), df2("j2"))).foreach { expected => + val expectedExpressions = expected.map(_.expr) + assert(p.partitionings.exists { + case h: HashPartitioning => expressionsEqual(h.expressions, expectedExpressions) + }) + } + + // Join on the column from the broadcasted side (i2, j2) and make sure output partitioning + // is maintained by checking no shuffle exchange is introduced. + val join2 = join1.join(t3, join1("i2") === t3("i3") && join1("j2") === t3("j3")) + val plan2 = join2.queryExecution.executedPlan + assert(collect(plan2) { case s: SortMergeJoinExec => s }.size == 1) + assert(collect(plan2) { case b: BroadcastHashJoinExec => b }.size == 1) + assert(collect(plan2) { case e: ShuffleExchangeExec => e }.isEmpty) + + // Validate the data with broadcast join off. + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df = join1.join(t3, join1("i2") === t3("i3") && join1("j2") === t3("j3")) + QueryTest.sameRows(join2.collect().toSeq, df.collect().toSeq) Review comment: `checkAnswer` supports comparing two dataframes. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org