Repository: spark
Updated Branches:
  refs/heads/master 6ac57fd0d -> bcceab649


[SPARK-22489][SQL] Shouldn't change broadcast join buildSide if user clearly specified

## What changes were proposed in this pull request?

How to reproduce:
```scala
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec

spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("table1")
spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value").createTempView("table2")

val bl = sql("SELECT /*+ MAPJOIN(t1) */ * FROM table1 t1 JOIN table2 t2 ON t1.key = t2.key").queryExecution.executedPlan

println(bl.children.head.asInstanceOf[BroadcastHashJoinExec].buildSide)
```
The result is `BuildRight`, but should be `BuildLeft`. This PR fixes this issue.

## How was this patch tested?

unit tests

Author: Yuming Wang <wgy...@gmail.com>

Closes #19714 from wangyum/SPARK-22489.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bcceab64
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bcceab64
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bcceab64

Branch: refs/heads/master
Commit: bcceab649510a45f4c4b8e44b157c9987adff6f4
Parents: 6ac57fd
Author: Yuming Wang <wgy...@gmail.com>
Authored: Thu Nov 30 15:36:26 2017 -0800
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Thu Nov 30 15:36:26 2017 -0800

----------------------------------------------------------------------
 docs/sql-programming-guide.md                   | 58 ++++++++++++++++
 .../spark/sql/execution/SparkStrategies.scala   | 67 ++++++++++++++-----
 .../execution/joins/BroadcastJoinSuite.scala    | 69 +++++++++++++++++++-
 3 files changed, 177 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bcceab64/docs/sql-programming-guide.md
----------------------------------------------------------------------
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 983770d..a1b9c3b 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -1492,6 +1492,64 @@ that these options will be deprecated in future release as more optimizations ar
 </tr>
 </table>
 
+## Broadcast Hint for SQL Queries
+
+The `BROADCAST` hint guides Spark to broadcast each specified table when joining them with another table or view.
+When Spark deciding the join methods, the broadcast hash join (i.e., BHJ) is preferred,
+even if the statistics is above the configuration `spark.sql.autoBroadcastJoinThreshold`.
+When both sides of a join are specified, Spark broadcasts the one having the lower statistics.
+Note Spark does not guarantee BHJ is always chosen, since not all cases (e.g. full outer join)
+support BHJ. When the broadcast nested loop join is selected, we still respect the hint.
+
+<div class="codetabs">
+
+<div data-lang="scala" markdown="1">
+
+{% highlight scala %}
+import org.apache.spark.sql.functions.broadcast
+broadcast(spark.table("src")).join(spark.table("records"), "key").show()
+{% endhighlight %}
+
+</div>
+
+<div data-lang="java" markdown="1">
+
+{% highlight java %}
+import static org.apache.spark.sql.functions.broadcast;
+broadcast(spark.table("src")).join(spark.table("records"), "key").show();
+{% endhighlight %}
+
+</div>
+
+<div data-lang="python" markdown="1">
+
+{% highlight python %}
+from pyspark.sql.functions import broadcast
+broadcast(spark.table("src")).join(spark.table("records"), "key").show()
+{% endhighlight %}
+
+</div>
+
+<div data-lang="r" markdown="1">
+
+{% highlight r %}
+src <- sql("SELECT * FROM src")
+records <- sql("SELECT * FROM records")
+head(join(broadcast(src), records, src$key == records$key))
+{% endhighlight %}
+
+</div>
+
+<div data-lang="sql" markdown="1">
+
+{% highlight sql %}
+-- We accept BROADCAST, BROADCASTJOIN and MAPJOIN for broadcast hint
+SELECT /*+ BROADCAST(r) */ * FROM records r JOIN src s ON r.key = s.key
+{% endhighlight %}
+
+</div>
+</div>
+
 # Distributed SQL Engine
 
 Spark SQL can also act as a distributed query engine using its JDBC/ODBC or command-line interface.


http://git-wip-us.apache.org/repos/asf/spark/blob/bcceab64/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 19b858f..1fe3cb1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
 import org.apache.spark.sql.execution.command._
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
-import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight}
+import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.StreamingQuery
@@ -91,12 +91,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
    * predicates can be evaluated by matching join keys. If found, Join implementations are chosen
    * with the following precedence:
    *
-   * - Broadcast: if one side of the join has an estimated physical size that is smaller than the
-   *     user-configurable [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold
-   *     or if that side has an explicit broadcast hint (e.g. the user applied the
-   *     [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame), then that side
-   *     of the join will be broadcasted and the other side will be streamed, with no shuffling
-   *     performed. If both sides of the join are eligible to be broadcasted then the
+   * - Broadcast: We prefer to broadcast the join side with an explicit broadcast hint(e.g. the
+   *     user applied the [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame).
+   *     If both sides have the broadcast hint, we prefer to broadcast the side with a smaller
+   *     estimated physical size. If neither one of the sides has the broadcast hint,
+   *     we only broadcast the join side if its estimated physical size that is smaller than
+   *     the user-configurable [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold.
    * - Shuffle hash join: if the average size of a single partition is small enough to build a hash
    *     table.
    * - Sort merge: if the matching join keys are sortable.
@@ -112,9 +112,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
      * Matches a plan whose output should be small enough to be used in broadcast join.
      */
     private def canBroadcast(plan: LogicalPlan): Boolean = {
-      plan.stats.hints.broadcast ||
-        (plan.stats.sizeInBytes >= 0 &&
-          plan.stats.sizeInBytes <= conf.autoBroadcastJoinThreshold)
+      plan.stats.sizeInBytes >= 0 && plan.stats.sizeInBytes <= conf.autoBroadcastJoinThreshold
     }
 
     /**
@@ -149,11 +147,46 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case _ => false
     }
 
+    private def broadcastSide(
+        canBuildLeft: Boolean,
+        canBuildRight: Boolean,
+        left: LogicalPlan,
+        right: LogicalPlan): BuildSide = {
+
+      def smallerSide =
+        if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft
+
+      val buildRight = canBuildRight && right.stats.hints.broadcast
+      val buildLeft = canBuildLeft && left.stats.hints.broadcast
+
+      if (buildRight && buildLeft) {
+        // Broadcast smaller side base on its estimated physical size
+        // if both sides have broadcast hint
+        smallerSide
+      } else if (buildRight) {
+        BuildRight
+      } else if (buildLeft) {
+        BuildLeft
+      } else if (canBuildRight && canBuildLeft) {
+        // for the last default broadcast nested loop join
+        smallerSide
+      } else {
+        throw new AnalysisException("Can not decide which side to broadcast for this join")
+      }
+    }
+
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
 
       // --- BroadcastHashJoin --------------------------------------------------------------------
 
       case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
+        if (canBuildRight(joinType) && right.stats.hints.broadcast) ||
+          (canBuildLeft(joinType) && left.stats.hints.broadcast) =>
+        val buildSide = broadcastSide(canBuildLeft(joinType), canBuildRight(joinType), left, right)
+        Seq(joins.BroadcastHashJoinExec(
+          leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right)))
+
+      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
         if canBuildRight(joinType) && canBroadcast(right) =>
         Seq(joins.BroadcastHashJoinExec(
           leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right)))
@@ -190,6 +223,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
 
       // Pick BroadcastNestedLoopJoin if one side could be broadcasted
       case j @ logical.Join(left, right, joinType, condition)
+        if (canBuildRight(joinType) && right.stats.hints.broadcast) ||
+          (canBuildLeft(joinType) && left.stats.hints.broadcast) =>
+        val buildSide = broadcastSide(canBuildLeft(joinType), canBuildRight(joinType), left, right)
+        joins.BroadcastNestedLoopJoinExec(
+          planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
+
+      case j @ logical.Join(left, right, joinType, condition)
         if canBuildRight(joinType) && canBroadcast(right) =>
         joins.BroadcastNestedLoopJoinExec(
           planLater(left), planLater(right), BuildRight, joinType, condition) :: Nil
@@ -203,12 +243,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         joins.CartesianProductExec(planLater(left), planLater(right), condition) :: Nil
 
       case logical.Join(left, right, joinType, condition) =>
-        val buildSide =
-          if (right.stats.sizeInBytes <= left.stats.sizeInBytes) {
-            BuildRight
-          } else {
-            BuildLeft
-          }
+        val buildSide = broadcastSide(canBuildLeft = true, canBuildRight = true, left, right)
         // This join could be very slow or OOM
         joins.BroadcastNestedLoopJoinExec(
           planLater(left), planLater(right), buildSide, joinType, condition) :: Nil


http://git-wip-us.apache.org/repos/asf/spark/blob/bcceab64/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index a0fad86..67e2cdc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -22,7 +22,7 @@ import scala.reflect.ClassTag
 import org.apache.spark.AccumulatorSuite
 import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft}
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.exchange.EnsureRequirements
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -223,4 +223,71 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
     assert(HashJoin.rewriteKeyExpr(l :: ss :: Nil) === l :: ss :: Nil)
     assert(HashJoin.rewriteKeyExpr(i :: ss :: Nil) === i :: ss :: Nil)
   }
+
+  test("Shouldn't change broadcast join buildSide if user clearly specified") {
+    def assertJoinBuildSide(sqlStr: String, joinMethod: String, buildSide: BuildSide): Any = {
+      val executedPlan = sql(sqlStr).queryExecution.executedPlan
+      executedPlan match {
+        case b: BroadcastNestedLoopJoinExec =>
+          assert(b.getClass.getSimpleName === joinMethod)
+          assert(b.buildSide === buildSide)
+        case w: WholeStageCodegenExec =>
+          assert(w.children.head.getClass.getSimpleName === joinMethod)
+          assert(w.children.head.asInstanceOf[BroadcastHashJoinExec].buildSide === buildSide)
+      }
+    }
+
+    withTempView("t1", "t2") {
+      spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1")
+      spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value")
+        .createTempView("t2")
+
+      val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes
+      val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes
+      assert(t1Size < t2Size)
+
+      val bh = BroadcastHashJoinExec.toString
+      val bl = BroadcastNestedLoopJoinExec.toString
+
+      // INNER JOIN && t1Size < t2Size => BuildLeft
+      assertJoinBuildSide(
+        "SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 JOIN t2 ON t1.key = t2.key", bh, BuildLeft)
+      // LEFT JOIN => BuildRight
+      assertJoinBuildSide(
+        "SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 LEFT JOIN t2 ON t1.key = t2.key", bh, BuildRight)
+      // RIGHT JOIN => BuildLeft
+      assertJoinBuildSide(
+        "SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 RIGHT JOIN t2 ON t1.key = t2.key", bh, BuildLeft)
+      // INNER JOIN && broadcast(t1) => BuildLeft
+      assertJoinBuildSide(
+        "SELECT /*+ MAPJOIN(t1) */ * FROM t1 JOIN t2 ON t1.key = t2.key", bh, BuildLeft)
+      // INNER JOIN && broadcast(t2) => BuildRight
+      assertJoinBuildSide(
+        "SELECT /*+ MAPJOIN(t2) */ * FROM t1 JOIN t2 ON t1.key = t2.key", bh, BuildRight)
+
+
+      withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0",
+        SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+        // INNER JOIN && t1Size < t2Size => BuildLeft
+        assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 JOIN t2", bl, BuildLeft)
+        // FULL JOIN && t1Size < t2Size => BuildLeft
+        assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 FULL JOIN t2", bl, BuildLeft)
+        // LEFT JOIN => BuildRight
+        assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 LEFT JOIN t2", bl, BuildRight)
+        // RIGHT JOIN => BuildLeft
+        assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 RIGHT JOIN t2", bl, BuildLeft)
+        // INNER JOIN && broadcast(t1) => BuildLeft
+        assertJoinBuildSide("SELECT /*+ MAPJOIN(t1) */ * FROM t1 JOIN t2", bl, BuildLeft)
+        // INNER JOIN && broadcast(t2) => BuildRight
+        assertJoinBuildSide("SELECT /*+ MAPJOIN(t2) */ * FROM t1 JOIN t2", bl, BuildRight)
+        // FULL OUTER && broadcast(t1) => BuildLeft
+        assertJoinBuildSide("SELECT /*+ MAPJOIN(t1) */ * FROM t1 FULL OUTER JOIN t2", bl, BuildLeft)
+        // FULL OUTER && broadcast(t2) => BuildRight
+        assertJoinBuildSide(
+          "SELECT /*+ MAPJOIN(t2) */ * FROM t1 FULL OUTER JOIN t2", bl, BuildRight)
+        // FULL OUTER && t1Size < t2Size => BuildLeft
+        assertJoinBuildSide("SELECT * FROM t1 FULL OUTER JOIN t2", bl, BuildLeft)
+      }
+    }
+  }
 }
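The build-side precedence that this commit introduces in `broadcastSide` can be exercised without a Spark session. The sketch below is a simplified, hypothetical model and not Spark code. `Side`, `PlanStats`, `JoinKind`, `oldEquiJoinSide` and `newEquiJoinSide` are illustrative names. Sizes are `Long` rather than Spark's `BigInt`, and only four join types are modeled. An `IllegalArgumentException` stands in for Spark's `AnalysisException`. The two `*EquiJoinSide` functions approximate the order in which the broadcast-hash-join cases are tried before and after the patch. They use the PR's reproduction: a `MAPJOIN(t1)` hint on an inner join where both tables are under the 10 MB default of `spark.sql.autoBroadcastJoinThreshold`. The table sizes are made up for illustration.

```scala
// Hypothetical, Spark-free model of the build-side choice from SPARK-22489.
// None of these names are Spark APIs; they only mirror the logic in the diff.
object BuildSideModel {
  sealed trait Side
  case object BuildLeft extends Side
  case object BuildRight extends Side

  // Stand-in for plan.stats.sizeInBytes and plan.stats.hints.broadcast.
  final case class PlanStats(sizeInBytes: Long, broadcastHint: Boolean)

  // Only the join types needed here; Spark has more (semi, anti, existence joins).
  sealed trait JoinKind
  case object Inner extends JoinKind
  case object LeftOuter extends JoinKind
  case object RightOuter extends JoinKind
  case object FullOuter extends JoinKind

  // Which side may be built (broadcast) for each modeled join type.
  def canBuildRight(j: JoinKind): Boolean = j == Inner || j == LeftOuter
  def canBuildLeft(j: JoinKind): Boolean = j == Inner || j == RightOuter

  // Same precedence as the patched broadcastSide: a hinted, buildable side wins;
  // if both are hinted (or neither is, but both are buildable) take the smaller one.
  def broadcastSide(
      canBuildLeft: Boolean,
      canBuildRight: Boolean,
      left: PlanStats,
      right: PlanStats): Side = {
    def smallerSide: Side =
      if (right.sizeInBytes <= left.sizeInBytes) BuildRight else BuildLeft

    val buildRight = canBuildRight && right.broadcastHint
    val buildLeft = canBuildLeft && left.broadcastHint

    if (buildRight && buildLeft) smallerSide
    else if (buildRight) BuildRight
    else if (buildLeft) BuildLeft
    else if (canBuildRight && canBuildLeft) smallerSide
    else throw new IllegalArgumentException(
      "Can not decide which side to broadcast for this join")
  }

  // Approximation of the pre-patch equi-join order: the right side is checked
  // first, and "small enough" counted exactly like an explicit hint.
  def oldEquiJoinSide(j: JoinKind, left: PlanStats, right: PlanStats, threshold: Long): Option[Side] = {
    def canBroadcast(p: PlanStats): Boolean =
      p.broadcastHint || (p.sizeInBytes >= 0 && p.sizeInBytes <= threshold)
    if (canBuildRight(j) && canBroadcast(right)) Some(BuildRight)
    else if (canBuildLeft(j) && canBroadcast(left)) Some(BuildLeft)
    else None // Spark would fall through to non-broadcast strategies
  }

  // Approximation of the post-patch order: hinted sides are considered first,
  // and only then the size-based threshold check.
  def newEquiJoinSide(j: JoinKind, left: PlanStats, right: PlanStats, threshold: Long): Option[Side] = {
    def underThreshold(p: PlanStats): Boolean = p.sizeInBytes >= 0 && p.sizeInBytes <= threshold
    if ((canBuildRight(j) && right.broadcastHint) || (canBuildLeft(j) && left.broadcastHint))
      Some(broadcastSide(canBuildLeft(j), canBuildRight(j), left, right))
    else if (canBuildRight(j) && underThreshold(right)) Some(BuildRight)
    else if (canBuildLeft(j) && underThreshold(left)) Some(BuildLeft)
    else None
  }

  def main(args: Array[String]): Unit = {
    val threshold = 10L * 1024 * 1024 // default spark.sql.autoBroadcastJoinThreshold (10 MB)
    val t1 = PlanStats(sizeInBytes = 40L, broadcastHint = true)  // MAPJOIN(t1)
    val t2 = PlanStats(sizeInBytes = 40L, broadcastHint = false)
    println(oldEquiJoinSide(Inner, t1, t2, threshold)) // Some(BuildRight): hint ignored
    println(newEquiJoinSide(Inner, t1, t2, threshold)) // Some(BuildLeft): hint respected
  }
}
```

The final catch-all join case in the patch calls `broadcastSide(canBuildLeft = true, canBuildRight = true, ...)`. In this model, that means a hint still picks the side for a join such as FULL OUTER where neither side is normally buildable: `broadcastSide(true, true, t1, t2)` returns `BuildLeft`. This matches the `FULL OUTER && broadcast(t1) => BuildLeft` case in the test suite.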