Repository: spark
Updated Branches:
  refs/heads/master e0d7665ce -> 2c0fe818a
[SPARK-22445][SQL][FOLLOW-UP] Respect stream-side child's needCopyResult in BroadcastHashJoin

## What changes were proposed in this pull request?

I found that #19656 causes some bugs; for example, it changed the result set of `q6` in TPC-DS (I keep tracking TPC-DS results daily [here](https://github.com/maropu/spark-tpcds-datagen/tree/master/reports/tests)):

- w/o pr19656
```
+-----+---+
|state|cnt|
+-----+---+
|   MA| 10|
|   AK| 10|
|   AZ| 11|
|   ME| 13|
|   VT| 14|
|   NV| 15|
|   NH| 16|
|   UT| 17|
|   NJ| 21|
|   MD| 22|
|   WY| 25|
|   NM| 26|
|   OR| 31|
|   WA| 36|
|   ND| 38|
|   ID| 39|
|   SC| 45|
|   WV| 50|
|   FL| 51|
|   OK| 53|
|   MT| 53|
|   CO| 57|
|   AR| 58|
|   NY| 58|
|   PA| 62|
|   AL| 63|
|   LA| 63|
|   SD| 70|
|   WI| 80|
| null| 81|
|   MI| 82|
|   NC| 82|
|   MS| 83|
|   CA| 84|
|   MN| 85|
|   MO| 88|
|   IL| 95|
|   IA|102|
|   TN|102|
|   IN|103|
|   KY|104|
|   NE|113|
|   OH|114|
|   VA|130|
|   KS|139|
|   GA|168|
|   TX|216|
+-----+---+
```

- w/ pr19656
```
+-----+---+
|state|cnt|
+-----+---+
|   RI| 14|
|   AK| 16|
|   FL| 20|
|   NJ| 21|
|   NM| 21|
|   NV| 22|
|   MA| 22|
|   MD| 22|
|   UT| 22|
|   AZ| 25|
|   SC| 28|
|   AL| 36|
|   MT| 36|
|   WA| 39|
|   ND| 41|
|   MI| 44|
|   AR| 45|
|   OR| 47|
|   OK| 52|
|   PA| 53|
|   LA| 55|
|   CO| 55|
|   NY| 64|
|   WV| 66|
|   SD| 72|
|   MS| 73|
|   NC| 79|
|   IN| 82|
| null| 85|
|   ID| 88|
|   MN| 91|
|   WI| 95|
|   IL| 96|
|   MO| 97|
|   CA|109|
|   CA|109|
|   TN|114|
|   NE|115|
|   KY|128|
|   OH|131|
|   IA|156|
|   TX|160|
|   VA|182|
|   KS|211|
|   GA|230|
+-----+---+
```

This PR keeps the original logic of `CodegenContext.copyResult` in `BroadcastHashJoinExec`.

## How was this patch tested?

Existing tests

Author: Takeshi Yamamuro <[email protected]>

Closes #19781 from maropu/SPARK-22445-bugfix.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2c0fe818
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2c0fe818
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2c0fe818

Branch: refs/heads/master
Commit: 2c0fe818a624cfdc76c752ec6bfe6a42e5680604
Parents: e0d7665
Author: Takeshi Yamamuro <[email protected]>
Authored: Wed Nov 22 09:09:50 2017 +0100
Committer: Wenchen Fan <[email protected]>
Committed: Wed Nov 22 09:09:50 2017 +0100

----------------------------------------------------------------------
 .../execution/joins/BroadcastHashJoinExec.scala | 15 ++++++-----
 .../scala/org/apache/spark/sql/JoinSuite.scala  | 27 +++++++++++++++++++-
 2 files changed, 35 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2c0fe818/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index 837b852..c96ed6e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -76,20 +76,23 @@ case class BroadcastHashJoinExec(
     streamedPlan.asInstanceOf[CodegenSupport].inputRDDs()
   }
 
-  override def needCopyResult: Boolean = joinType match {
+  private def multipleOutputForOneInput: Boolean = joinType match {
     case _: InnerLike | LeftOuter | RightOuter =>
       // For inner and outer joins, one row from the streamed side may produce multiple result rows,
-      // if the build side has duplicated keys. Then we need to copy the result rows before putting
-      // them in a buffer, because these result rows share one UnsafeRow instance. Note that here
-      // we wait for the broadcast to be finished, which is a no-op because it's already finished
-      // when we wait it in `doProduce`.
+      // if the build side has duplicated keys. Note that here we wait for the broadcast to be
+      // finished, which is a no-op because it's already finished when we wait it in `doProduce`.
       !buildPlan.executeBroadcast[HashedRelation]().value.keyIsUnique
 
     // Other joins types(semi, anti, existence) can at most produce one result row for one input
-    // row from the streamed side, so no need to copy the result rows.
+    // row from the streamed side.
     case _ => false
   }
 
+  // If the streaming side needs to copy result, this join plan needs to copy too. Otherwise,
+  // this join plan only needs to copy result if it may output multiple rows for one input.
+  override def needCopyResult: Boolean =
+    streamedPlan.asInstanceOf[CodegenSupport].needCopyResult || multipleOutputForOneInput
+
   override def doProduce(ctx: CodegenContext): String = {
     streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
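An aside on why the copy matters (not part of the commit): the comment removed in the hunk above, "these result rows share one UnsafeRow instance", is the crux of the bug. Whole-stage-codegen operators reuse a single `UnsafeRow` for every result row they emit, so a consumer that buffers rows must `copy()` each one first, and `needCopyResult` is how an operator signals that its output rows alias one instance. A minimal sketch of that aliasing, using Spark's internal `UnsafeProjection` (illustrative only, not code from this patch):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types.{DataType, IntegerType}

// An UnsafeProjection writes every output into one reused UnsafeRow, just as
// generated join code reuses one result row. Buffering the returned references
// without copy() makes every buffered entry observe the last write.
val proj = UnsafeProjection.create(Array[DataType](IntegerType))

val shared = Seq(1, 2, 3).map(i => proj(InternalRow(i)))        // three refs, one instance
val copied = Seq(1, 2, 3).map(i => proj(InternalRow(i)).copy()) // three independent rows

assert(shared.map(_.getInt(0)) == Seq(3, 3, 3)) // aliased: all see the final value
assert(copied.map(_.getInt(0)) == Seq(1, 2, 3)) // copied: each value preserved
```

The restored `needCopyResult` above ORs in the streamed child's flag: even when the broadcast side has unique keys, so this join emits at most one row per input, a streamed child such as a `SortMergeJoinExec` may still reuse its row instance, and that requirement has to propagate up the plan.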
http://git-wip-us.apache.org/repos/asf/spark/blob/2c0fe818/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 226cc30..771e118 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder}
-import org.apache.spark.sql.execution.SortExec
+import org.apache.spark.sql.execution.{BinaryExecNode, SortExec}
 import org.apache.spark.sql.execution.joins._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
@@ -857,4 +857,29 @@ class JoinSuite extends QueryTest with SharedSQLContext {
 
     joinQueries.foreach(assertJoinOrdering)
   }
+
+  test("SPARK-22445 Respect stream-side child's needCopyResult in BroadcastHashJoin") {
+    val df1 = Seq((2, 3), (2, 5), (2, 2), (3, 8), (2, 1)).toDF("k", "v1")
+    val df2 = Seq((2, 8), (3, 7), (3, 4), (1, 2)).toDF("k", "v2")
+    val df3 = Seq((1, 1), (3, 2), (4, 3), (5, 1)).toDF("k", "v3")
+
+    withSQLConf(
+        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+        SQLConf.JOIN_REORDER_ENABLED.key -> "false") {
+      val df = df1.join(df2, "k").join(functions.broadcast(df3), "k")
+      val plan = df.queryExecution.sparkPlan
+
+      // Check if `needCopyResult` in `BroadcastHashJoin` is correct when smj->bhj
+      val joins = new collection.mutable.ArrayBuffer[BinaryExecNode]()
+      plan.foreachUp {
+        case j: BroadcastHashJoinExec => joins += j
+        case j: SortMergeJoinExec => joins += j
+        case _ =>
+      }
+      assert(joins.size == 2)
+      assert(joins(0).isInstanceOf[SortMergeJoinExec])
+      assert(joins(1).isInstanceOf[BroadcastHashJoinExec])
+      checkAnswer(df, Row(3, 8, 7, 2) :: Row(3, 8, 4, 2) :: Nil)
+    }
+  }
 }
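For reference, the scenario the new test pins down can also be reproduced interactively. A hypothetical spark-shell session, mirroring the test above (the string keys are the ones behind the `SQLConf` entries the test sets):

```scala
// Build the SortMergeJoin -> BroadcastHashJoin plan shape from the test above.
// Assumes a spark-shell session, i.e. `spark` and the toDF implicits in scope.
import org.apache.spark.sql.functions

spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1") // SQLConf.AUTO_BROADCASTJOIN_THRESHOLD
spark.conf.set("spark.sql.cbo.joinReorder.enabled", "false") // SQLConf.JOIN_REORDER_ENABLED

import spark.implicits._
val df1 = Seq((2, 3), (2, 5), (2, 2), (3, 8), (2, 1)).toDF("k", "v1")
val df2 = Seq((2, 8), (3, 7), (3, 4), (1, 2)).toDF("k", "v2")
val df3 = Seq((1, 1), (3, 2), (4, 3), (5, 1)).toDF("k", "v3")

// With auto-broadcast disabled, df1 join df2 runs as a SortMergeJoin; the
// explicit hint makes only the outer join a BroadcastHashJoin on top of it.
val df = df1.join(df2, "k").join(functions.broadcast(df3), "k")

df.explain() // expect a SortMergeJoin feeding a BroadcastHashJoin
df.show()    // expect exactly the two rows checked above: (3,8,7,2) and (3,8,4,2)
```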
