This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 43da473 [SPARK-27225][SQL] Implement join strategy hints
43da473 is described below
commit 43da473c1c49a99a038f558c64b5b34eb2fec764
Author: maryannxue <[email protected]>
AuthorDate: Fri Apr 12 00:14:37 2019 +0800
[SPARK-27225][SQL] Implement join strategy hints
## What changes were proposed in this pull request?
This PR extends the existing BROADCAST join hint (for both broadcast-hash
join and broadcast-nested-loop join) by implementing other join strategy hints
corresponding to the rest of Spark's existing join strategies: shuffle-hash,
sort-merge, cartesian-product. The hint names SHUFFLE_MERGE, SHUFFLE_HASH, and
SHUFFLE_REPLICATE_NL differ in part from the internal code names in order to
make them clearer to users and to better reflect the actual algorithms.
The hinted strategy will be used for the join with which it is associated
whenever it is applicable.
Conflict resolution rules in case of multiple hints:
1. Conflicts within either side of the join: take the first strategy hint
specified in the query, or the top hint node in Dataset. For example, in
"select /*+ merge(t1) */ /*+ broadcast(t1) */ k1, v2 from t1 join t2 on t1.k1 =
t2.k2", take "merge(t1)"; in
```df1.hint("merge").hint("shuffle_hash").join(df2)```, take "shuffle_hash".
This is a general hint conflict resolution strategy, not specific to join
strategy hints.
2. Conflicts between two sides of the join:
a) In case of different strategy hints, hints are prioritized as
```BROADCAST``` over ```SHUFFLE_MERGE``` over ```SHUFFLE_HASH``` over
```SHUFFLE_REPLICATE_NL```.
b) In case of same strategy hints but conflicts in build side, choose the
build side based on join type and size.
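The two resolution rules above can be sketched in a few lines. This is a standalone Python illustration of the described behavior, not Spark code, and the function names are invented for the sketch:

```python
# Priority order between the two sides of a join: earlier entries win (rule 2a).
PRIORITY = ["BROADCAST", "SHUFFLE_MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"]

def resolve_side(hints):
    """Rule 1: within one side, the first strategy hint specified wins."""
    return hints[0] if hints else None

def resolve_join(left_hint, right_hint):
    """Rule 2a: between the two sides, the higher-priority strategy wins."""
    candidates = [h for h in (left_hint, right_hint) if h is not None]
    if not candidates:
        return None
    return min(candidates, key=PRIORITY.index)

# Rule 1: merge(t1) is specified before broadcast(t1), so merge wins on that side.
assert resolve_side(["SHUFFLE_MERGE", "BROADCAST"]) == "SHUFFLE_MERGE"
# Rule 2a: BROADCAST on one side beats SHUFFLE_HASH on the other.
assert resolve_join("SHUFFLE_HASH", "BROADCAST") == "BROADCAST"
```

Rule 2b (same hint on both sides) falls back to join type and relation sizes and is not modeled here.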
## How was this patch tested?
Added new UTs.
Closes #24164 from maryannxue/join-hints.
Lead-authored-by: maryannxue <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
docs/sql-performance-tuning.md | 24 +-
.../spark/sql/catalyst/analysis/Analyzer.scala | 2 +-
.../spark/sql/catalyst/analysis/ResolveHints.scala | 92 +++--
.../spark/sql/catalyst/catalog/interface.scala | 2 +-
.../catalyst/optimizer/EliminateResolvedHint.scala | 52 ++-
.../spark/sql/catalyst/plans/logical/hints.scala | 95 ++++-
.../sql/catalyst/analysis/ResolveHintsSuite.scala | 47 ++-
.../spark/sql/execution/SparkStrategies.scala | 366 +++++++++++--------
.../scala/org/apache/spark/sql/functions.scala | 4 +-
.../org/apache/spark/sql/CachedTableSuite.scala | 4 +-
.../scala/org/apache/spark/sql/JoinHintSuite.scala | 396 +++++++++++++++++++--
.../sql/execution/joins/BroadcastJoinSuite.scala | 5 +-
12 files changed, 837 insertions(+), 252 deletions(-)
diff --git a/docs/sql-performance-tuning.md b/docs/sql-performance-tuning.md
index 6856974..2a1edda 100644
--- a/docs/sql-performance-tuning.md
+++ b/docs/sql-performance-tuning.md
@@ -107,14 +107,22 @@ that these options will be deprecated in future release as more optimizations ar
 </tr>
 </table>
-## Broadcast Hint for SQL Queries
-
-The `BROADCAST` hint guides Spark to broadcast each specified table when joining them with another table or view.
-When Spark deciding the join methods, the broadcast hash join (i.e., BHJ) is preferred,
-even if the statistics is above the configuration `spark.sql.autoBroadcastJoinThreshold`.
-When both sides of a join are specified, Spark broadcasts the one having the lower statistics.
-Note Spark does not guarantee BHJ is always chosen, since not all cases (e.g. full outer join)
-support BHJ. When the broadcast nested loop join is selected, we still respect the hint.
+## Join Strategy Hints for SQL Queries
+
+The join strategy hints, namely `BROADCAST`, `MERGE`, `SHUFFLE_HASH` and `SHUFFLE_REPLICATE_NL`,
+instruct Spark to use the hinted strategy on each specified relation when joining them with another
+relation. For example, when the `BROADCAST` hint is used on table 't1', broadcast join (either
+broadcast hash join or broadcast nested loop join depending on whether there is any equi-join key)
+with 't1' as the build side will be prioritized by Spark even if the size of table 't1' suggested
+by the statistics is above the configuration `spark.sql.autoBroadcastJoinThreshold`.
+
+When different join strategy hints are specified on both sides of a join, Spark prioritizes the
+`BROADCAST` hint over the `MERGE` hint over the `SHUFFLE_HASH` hint over the `SHUFFLE_REPLICATE_NL`
+hint. When both sides are specified with the `BROADCAST` hint or the `SHUFFLE_HASH` hint, Spark will
+pick the build side based on the join type and the sizes of the relations.
+
+Note that there is no guarantee that Spark will choose the join strategy specified in the hint since
+a specific strategy may not support all join types.
<div class="codetabs">
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 01e40e6..02d83e7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -153,7 +153,7 @@ class Analyzer(
lazy val batches: Seq[Batch] = Seq(
Batch("Hints", fixedPoint,
- new ResolveHints.ResolveBroadcastHints(conf),
+ new ResolveHints.ResolveJoinStrategyHints(conf),
ResolveHints.ResolveCoalesceHints,
ResolveHints.RemoveAllHints),
Batch("Simple Sanity Check", Once,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
index dbd4ed8..9440a3f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.analysis
import java.util.Locale
+import scala.collection.mutable
+
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.IntegerLiteral
import org.apache.spark.sql.catalyst.plans.logical._
@@ -28,45 +30,66 @@ import org.apache.spark.sql.internal.SQLConf
/**
- * Collection of rules related to hints. The only hint currently available is broadcast join hint.
+ * Collection of rules related to hints. The only hint currently available is join strategy hint.
  *
  * Note that this is separated into two rules because in the future we might introduce new hint
- * rules that have different ordering requirements from broadcast.
+ * rules that have different ordering requirements from join strategies.
*/
object ResolveHints {
/**
-   * For broadcast hint, we accept "BROADCAST", "BROADCASTJOIN", and "MAPJOIN", and a sequence of
-   * relation aliases can be specified in the hint. A broadcast hint plan node will be inserted
-   * on top of any relation (that is not aliased differently), subquery, or common table expression
-   * that match the specified name.
+   * The list of allowed join strategy hints is defined in [[JoinStrategyHint.strategies]], and a
+   * sequence of relation aliases can be specified with a join strategy hint, e.g., "MERGE(a, c)",
+   * "BROADCAST(a)". A join strategy hint plan node will be inserted on top of any relation (that
+   * is not aliased differently), subquery, or common table expression that match the specified
+   * name.
*
   * The hint resolution works by recursively traversing down the query plan to find a relation or
-   * subquery that matches one of the specified broadcast aliases. The traversal does not go past
-   * beyond any existing broadcast hints, subquery aliases.
+   * subquery that matches one of the specified relation aliases. The traversal does not go
+   * beyond any view reference, with clause or subquery alias.
*
* This rule must happen before common table expressions.
*/
-  class ResolveBroadcastHints(conf: SQLConf) extends Rule[LogicalPlan] {
-    private val BROADCAST_HINT_NAMES = Set("BROADCAST", "BROADCASTJOIN", "MAPJOIN")
+  class ResolveJoinStrategyHints(conf: SQLConf) extends Rule[LogicalPlan] {
+    private val STRATEGY_HINT_NAMES = JoinStrategyHint.strategies.flatMap(_.hintAliases)

     def resolver: Resolver = conf.resolver

-    private def applyBroadcastHint(plan: LogicalPlan, toBroadcast: Set[String]): LogicalPlan = {
+    private def createHintInfo(hintName: String): HintInfo = {
+      HintInfo(strategy =
+        JoinStrategyHint.strategies.find(
+          _.hintAliases.map(
+            _.toUpperCase(Locale.ROOT)).contains(hintName.toUpperCase(Locale.ROOT))))
+    }
+
+    private def applyJoinStrategyHint(
+        plan: LogicalPlan,
+        relations: mutable.HashSet[String],
+        hintName: String): LogicalPlan = {
// Whether to continue recursing down the tree
var recurse = true
val newNode = CurrentOrigin.withOrigin(plan.origin) {
plan match {
-        case u: UnresolvedRelation if toBroadcast.exists(resolver(_, u.tableIdentifier.table)) =>
-          ResolvedHint(plan, HintInfo(broadcast = true))
-        case r: SubqueryAlias if toBroadcast.exists(resolver(_, r.alias)) =>
-          ResolvedHint(plan, HintInfo(broadcast = true))
+        case ResolvedHint(u: UnresolvedRelation, hint)
+            if relations.exists(resolver(_, u.tableIdentifier.table)) =>
+          relations.remove(u.tableIdentifier.table)
+          ResolvedHint(u, createHintInfo(hintName).merge(hint, handleOverriddenHintInfo))
+        case ResolvedHint(r: SubqueryAlias, hint)
+            if relations.exists(resolver(_, r.alias)) =>
+          relations.remove(r.alias)
+          ResolvedHint(r, createHintInfo(hintName).merge(hint, handleOverriddenHintInfo))
+
+        case u: UnresolvedRelation if relations.exists(resolver(_, u.tableIdentifier.table)) =>
+          relations.remove(u.tableIdentifier.table)
+          ResolvedHint(plan, createHintInfo(hintName))
+        case r: SubqueryAlias if relations.exists(resolver(_, r.alias)) =>
+          relations.remove(r.alias)
+          ResolvedHint(plan, createHintInfo(hintName))
         case _: ResolvedHint | _: View | _: With | _: SubqueryAlias =>
           // Don't traverse down these nodes.
-          // For an existing broadcast hint, there is no point going down (if we do, we either
-          // won't change the structure, or will introduce another broadcast hint that is useless.
+          // For an existing strategy hint, there is no chance for a match from this point down.
           // The rest (view, with, subquery) indicates different scopes that we shouldn't traverse
           // down. Note that technically when this rule is executed, we haven't completed view
           // resolution yet and as a result the view part should be deadcode. I'm leaving it here
@@ -80,25 +103,38 @@ object ResolveHints {
}
       if ((plan fastEquals newNode) && recurse) {
-        newNode.mapChildren(child => applyBroadcastHint(child, toBroadcast))
+        newNode.mapChildren(child => applyJoinStrategyHint(child, relations, hintName))
} else {
newNode
}
}
+    private def handleOverriddenHintInfo(hint: HintInfo): Unit = {
+      logWarning(s"Join hint $hint is overridden by another hint and will not take effect.")
+    }
+
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
-      case h: UnresolvedHint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) =>
+      case h: UnresolvedHint if STRATEGY_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) =>
         if (h.parameters.isEmpty) {
-          // If there is no table alias specified, turn the entire subtree into a BroadcastHint.
-          ResolvedHint(h.child, HintInfo(broadcast = true))
+          // If there is no table alias specified, apply the hint on the entire subtree.
+          ResolvedHint(h.child, createHintInfo(h.name))
         } else {
-          // Otherwise, find within the subtree query plans that should be broadcasted.
-          applyBroadcastHint(h.child, h.parameters.map {
+          // Otherwise, find within the subtree query plans to apply the hint.
+          val relationNames = h.parameters.map {
             case tableName: String => tableName
             case tableId: UnresolvedAttribute => tableId.name
-            case unsupported => throw new AnalysisException("Broadcast hint parameter should be " +
-              s"an identifier or string but was $unsupported (${unsupported.getClass}")
-          }.toSet)
+            case unsupported => throw new AnalysisException("Join strategy hint parameter " +
+              s"should be an identifier or string but was $unsupported (${unsupported.getClass}")
+          }
+          val relationNameSet = new mutable.HashSet[String]
+          relationNames.foreach(relationNameSet.add)
+
+          val applied = applyJoinStrategyHint(h.child, relationNameSet, h.name)
+          relationNameSet.foreach { n =>
+            logWarning(s"Could not find relation '$n' for join strategy hint " +
+              s"'${h.name}${relationNames.mkString("(", ", ", ")")}'.")
+          }
+          applied
}
}
}
@@ -135,7 +171,9 @@ object ResolveHints {
*/
object RemoveAllHints extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
- case h: UnresolvedHint => h.child
+ case h: UnresolvedHint =>
+        logWarning(s"Unrecognized hint: ${h.name}${h.parameters.mkString("(", ", ", ")")}")
+ h.child
}
}
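The alias lookup in `createHintInfo` above amounts to a case-insensitive search over each strategy's alias set. Below is a standalone Python sketch of that behavior, not Spark code; the alias table is copied from the hint definitions introduced in this commit:

```python
# Strategy name -> set of accepted hint aliases (as defined in hints.scala).
HINT_ALIASES = {
    "BROADCAST": {"BROADCAST", "BROADCASTJOIN", "MAPJOIN"},
    "SHUFFLE_MERGE": {"SHUFFLE_MERGE", "MERGE", "MERGEJOIN"},
    "SHUFFLE_HASH": {"SHUFFLE_HASH"},
    "SHUFFLE_REPLICATE_NL": {"SHUFFLE_REPLICATE_NL"},
}

def create_hint_info(hint_name):
    """Map a hint name, case-insensitively, to the strategy whose aliases contain it."""
    upper = hint_name.upper()
    for strategy, aliases in HINT_ALIASES.items():
        if upper in aliases:
            return strategy
    return None  # unrecognized hint names resolve to no strategy

assert create_hint_info("mapjoin") == "BROADCAST"
assert create_hint_info("Merge") == "SHUFFLE_MERGE"
```

In the real rule an unrecognized hint never reaches `createHintInfo`; it is dropped (with a warning) by `RemoveAllHints` instead.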
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
index 6006637..2d64672 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
@@ -374,7 +374,7 @@ object CatalogTable {
 /**
  * This class of statistics is used in [[CatalogTable]] to interact with metastore.
  * We define this new class instead of directly using [[Statistics]] here because there are no
- * concepts of attributes or broadcast hint in catalog.
+ * concepts of attributes in catalog.
  */
case class CatalogStatistics(
sizeInBytes: BigInt,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala
index a136f04..5586690 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala
@@ -30,30 +30,58 @@ object EliminateResolvedHint extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = {
     val pulledUp = plan transformUp {
       case j: Join =>
-        val leftHint = mergeHints(collectHints(j.left))
-        val rightHint = mergeHints(collectHints(j.right))
-        j.copy(hint = JoinHint(leftHint, rightHint))
+        val (newLeft, leftHints) = extractHintsFromPlan(j.left)
+        val (newRight, rightHints) = extractHintsFromPlan(j.right)
+        val newJoinHint = JoinHint(mergeHints(leftHints), mergeHints(rightHints))
+        j.copy(left = newLeft, right = newRight, hint = newJoinHint)
     }
     pulledUp.transformUp {
-      case h: ResolvedHint => h.child
+      case h: ResolvedHint =>
+        handleInvalidHintInfo(h.hints)
+        h.child
     }
   }
+  /**
+   * Combine a list of [[HintInfo]]s into one [[HintInfo]].
+   */
   private def mergeHints(hints: Seq[HintInfo]): Option[HintInfo] = {
-    hints.reduceOption((h1, h2) => HintInfo(
-      broadcast = h1.broadcast || h2.broadcast))
+    hints.reduceOption((h1, h2) => h1.merge(h2, handleOverriddenHintInfo))
   }
-  private def collectHints(plan: LogicalPlan): Seq[HintInfo] = {
+  /**
+   * Extract all hints from the plan, returning a list of extracted hints and the transformed plan
+   * with [[ResolvedHint]] nodes removed. The returned hint list comes in top-down order.
+   * Note that hints can only be extracted from under certain nodes. Those that cannot be extracted
+   * in this method will be cleaned up later by this rule, and may emit warnings depending on the
+   * configurations.
+   */
+  private def extractHintsFromPlan(plan: LogicalPlan): (LogicalPlan, Seq[HintInfo]) = {
     plan match {
-      case h: ResolvedHint => collectHints(h.child) :+ h.hints
-      case u: UnaryNode => collectHints(u.child)
+      case h: ResolvedHint =>
+        val (plan, hints) = extractHintsFromPlan(h.child)
+        (plan, h.hints +: hints)
+      case u: UnaryNode =>
+        val (plan, hints) = extractHintsFromPlan(u.child)
+        (u.withNewChildren(Seq(plan)), hints)
       // TODO revisit this logic:
       // except and intersect are semi/anti-joins which won't return more data than
       // their left argument, so the broadcast hint should be propagated here
-      case i: Intersect => collectHints(i.left)
-      case e: Except => collectHints(e.left)
-      case _ => Seq.empty
+      case i: Intersect =>
+        val (plan, hints) = extractHintsFromPlan(i.left)
+        (i.copy(left = plan), hints)
+      case e: Except =>
+        val (plan, hints) = extractHintsFromPlan(e.left)
+        (e.copy(left = plan), hints)
+      case p: LogicalPlan => (p, Seq.empty)
     }
   }
+
+  private def handleInvalidHintInfo(hint: HintInfo): Unit = {
+    logWarning(s"A join hint $hint is specified but it is not part of a join relation.")
+  }
+
+  private def handleOverriddenHintInfo(hint: HintInfo): Unit = {
+    logWarning(s"Join hint $hint is overridden by another hint and will not take effect.")
+  }
}
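The recursion in `extractHintsFromPlan` above can be illustrated on a toy tree. This is a Python sketch, not Spark code; the tuple-based plan encoding is invented for the illustration:

```python
def extract_hints(plan):
    """Strip 'hint' nodes from a toy plan tree, returning (new_plan, hints top-down)."""
    kind = plan[0]
    if kind == "hint":                       # ("hint", hint_info, child)
        child, hints = extract_hints(plan[2])
        return child, [plan[1]] + hints      # this hint precedes deeper ones
    if kind == "unary":                      # ("unary", child): hints pass through
        child, hints = extract_hints(plan[1])
        return ("unary", child), hints
    return plan, []                          # leaf or barrier node: stop extracting

plan = ("hint", "merge", ("unary", ("hint", "broadcast", ("scan",))))
new_plan, hints = extract_hints(plan)
assert new_plan == ("unary", ("scan",))      # both hint nodes removed
assert hints == ["merge", "broadcast"]       # top-down order preserved
```

In the real rule, the hints pulled up this way are then merged per join side and attached to the `Join` node as a `JoinHint`.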
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
index b2ba725..870dd87 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
@@ -66,17 +66,94 @@ object JoinHint {
 /**
  * The hint attributes to be applied on a specific node.
  *
- * @param broadcast If set to true, it indicates that the broadcast hash join is the preferred join
- *                  strategy and the node with this hint is preferred to be the build side.
+ * @param strategy The preferred join strategy.
  */
-case class HintInfo(broadcast: Boolean = false) {
+case class HintInfo(strategy: Option[JoinStrategyHint] = None) {

-  override def toString: String = {
-    val hints = scala.collection.mutable.ArrayBuffer.empty[String]
-    if (broadcast) {
-      hints += "broadcast"
+  /**
+   * Combine this [[HintInfo]] with another [[HintInfo]] and return the new [[HintInfo]].
+   * @param other the other [[HintInfo]]
+   * @param hintOverriddenCallback a callback to notify if any [[HintInfo]] has been overridden
+   *                               in this merge.
+   *
+   * Currently, for join strategy hints, the new [[HintInfo]] will contain the strategy in this
+   * [[HintInfo]] if defined, otherwise the strategy in the other [[HintInfo]]. The
+   * `hintOverriddenCallback` will be called if this [[HintInfo]] and the other [[HintInfo]]
+   * both have a strategy defined but the join strategies are different.
+   */
+  def merge(other: HintInfo, hintOverriddenCallback: HintInfo => Unit): HintInfo = {
+    if (this.strategy.isDefined &&
+        other.strategy.isDefined &&
+        this.strategy.get != other.strategy.get) {
+      hintOverriddenCallback(other)
     }
-
-    if (hints.isEmpty) "none" else hints.mkString("(", ", ", ")")
+    HintInfo(strategy = this.strategy.orElse(other.strategy))
   }
+
+  override def toString: String = strategy.map(s => s"(strategy=$s)").getOrElse("none")
+}
+
+sealed abstract class JoinStrategyHint {
+
+ def displayName: String
+ def hintAliases: Set[String]
+
+ override def toString: String = displayName
+}
+
+/**
+ * The enumeration of join strategy hints.
+ *
+ * The hinted strategy will be used for the join with which it is associated if doable. In case
+ * of contradicting strategy hints specified for each side of the join, hints are prioritized as
+ * BROADCAST over SHUFFLE_MERGE over SHUFFLE_HASH over SHUFFLE_REPLICATE_NL.
+ */
+object JoinStrategyHint {
+
+ val strategies: Set[JoinStrategyHint] = Set(
+ BROADCAST,
+ SHUFFLE_MERGE,
+ SHUFFLE_HASH,
+ SHUFFLE_REPLICATE_NL)
+}
+
+/**
+ * The hint for broadcast hash join or broadcast nested loop join, depending on the availability
+ * of equi-join keys.
+ */
+case object BROADCAST extends JoinStrategyHint {
+ override def displayName: String = "broadcast"
+ override def hintAliases: Set[String] = Set(
+ "BROADCAST",
+ "BROADCASTJOIN",
+ "MAPJOIN")
+}
+
+/**
+ * The hint for shuffle sort merge join.
+ */
+case object SHUFFLE_MERGE extends JoinStrategyHint {
+ override def displayName: String = "merge"
+ override def hintAliases: Set[String] = Set(
+ "SHUFFLE_MERGE",
+ "MERGE",
+ "MERGEJOIN")
+}
+
+/**
+ * The hint for shuffle hash join.
+ */
+case object SHUFFLE_HASH extends JoinStrategyHint {
+ override def displayName: String = "shuffle_hash"
+ override def hintAliases: Set[String] = Set(
+ "SHUFFLE_HASH")
+}
+
+/**
+ * The hint for shuffle-and-replicate nested loop join, a.k.a. cartesian product join.
+ */
+case object SHUFFLE_REPLICATE_NL extends JoinStrategyHint {
+ override def displayName: String = "shuffle_replicate_nl"
+ override def hintAliases: Set[String] = Set(
+ "SHUFFLE_REPLICATE_NL")
}
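The `HintInfo.merge` contract above (receiver wins; the callback fires only on a genuine conflict) is small enough to sketch standalone. Python illustration, not Spark code:

```python
def merge(this_strategy, other_strategy, on_overridden):
    """Mimic HintInfo.merge: keep this side's strategy if defined, else the other's.

    The callback is invoked only when both sides define different strategies,
    i.e. when the other side's hint is actually being overridden.
    """
    if (this_strategy is not None and other_strategy is not None
            and this_strategy != other_strategy):
        on_overridden(other_strategy)
    return this_strategy if this_strategy is not None else other_strategy

overridden = []
assert merge("SHUFFLE_HASH", "SHUFFLE_MERGE", overridden.append) == "SHUFFLE_HASH"
assert overridden == ["SHUFFLE_MERGE"]   # the losing hint was reported
assert merge(None, "BROADCAST", overridden.append) == "BROADCAST"
assert overridden == ["SHUFFLE_MERGE"]   # no conflict, so no new callback
```

This matches the rule-1 behavior in the commit message: the hint resolved first wins, and the later, conflicting hint is logged as overridden.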
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
index 563e8ad..474e58a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
@@ -17,6 +17,11 @@
package org.apache.spark.sql.catalyst.analysis
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.log4j.{AppenderSkeleton, Level}
+import org.apache.log4j.spi.LoggingEvent
+
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
@@ -27,6 +32,14 @@ import org.apache.spark.sql.catalyst.plans.logical._
class ResolveHintsSuite extends AnalysisTest {
import org.apache.spark.sql.catalyst.analysis.TestRelations._
+  class MockAppender extends AppenderSkeleton {
+    val loggingEvents = new ArrayBuffer[LoggingEvent]()
+
+    override def append(loggingEvent: LoggingEvent): Unit = loggingEvents.append(loggingEvent)
+    override def close(): Unit = {}
+    override def requiresLayout(): Boolean = false
+  }
+
test("invalid hints should be ignored") {
checkAnalysis(
      UnresolvedHint("some_random_hint_that_does_not_exist", Seq("TaBlE"), table("TaBlE")),
@@ -37,17 +50,17 @@ class ResolveHintsSuite extends AnalysisTest {
test("case-sensitive or insensitive parameters") {
checkAnalysis(
UnresolvedHint("MAPJOIN", Seq("TaBlE"), table("TaBlE")),
- ResolvedHint(testRelation, HintInfo(broadcast = true)),
+ ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))),
caseSensitive = false)
checkAnalysis(
UnresolvedHint("MAPJOIN", Seq("table"), table("TaBlE")),
- ResolvedHint(testRelation, HintInfo(broadcast = true)),
+ ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))),
caseSensitive = false)
checkAnalysis(
UnresolvedHint("MAPJOIN", Seq("TaBlE"), table("TaBlE")),
- ResolvedHint(testRelation, HintInfo(broadcast = true)),
+ ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))),
caseSensitive = true)
checkAnalysis(
@@ -59,28 +72,29 @@ class ResolveHintsSuite extends AnalysisTest {
   test("multiple broadcast hint aliases") {
     checkAnalysis(
       UnresolvedHint("MAPJOIN", Seq("table", "table2"), table("table").join(table("table2"))),
-      Join(ResolvedHint(testRelation, HintInfo(broadcast = true)),
-        ResolvedHint(testRelation2, HintInfo(broadcast = true)), Inner, None, JoinHint.NONE),
+      Join(ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))),
+        ResolvedHint(testRelation2, HintInfo(strategy = Some(BROADCAST))),
+        Inner, None, JoinHint.NONE),
       caseSensitive = false)
   }
   test("do not traverse past existing broadcast hints") {
     checkAnalysis(
       UnresolvedHint("MAPJOIN", Seq("table"),
-        ResolvedHint(table("table").where('a > 1), HintInfo(broadcast = true))),
-      ResolvedHint(testRelation.where('a > 1), HintInfo(broadcast = true)).analyze,
+        ResolvedHint(table("table").where('a > 1), HintInfo(strategy = Some(BROADCAST)))),
+      ResolvedHint(testRelation.where('a > 1), HintInfo(strategy = Some(BROADCAST))).analyze,
       caseSensitive = false)
   }
   test("should work for subqueries") {
     checkAnalysis(
       UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").as("tableAlias")),
-      ResolvedHint(testRelation, HintInfo(broadcast = true)),
+      ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))),
       caseSensitive = false)

     checkAnalysis(
       UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").subquery('tableAlias)),
-      ResolvedHint(testRelation, HintInfo(broadcast = true)),
+      ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))),
       caseSensitive = false)

     // Negative case: if the alias doesn't match, don't match the original table name.
@@ -105,7 +119,7 @@ class ResolveHintsSuite extends AnalysisTest {
           |SELECT /*+ BROADCAST(ctetable) */ * FROM ctetable
         """.stripMargin
       ),
-      ResolvedHint(testRelation.where('a > 1).select('a), HintInfo(broadcast = true))
+      ResolvedHint(testRelation.where('a > 1).select('a), HintInfo(strategy = Some(BROADCAST)))
         .select('a).analyze,
       caseSensitive = false)
   }
@@ -155,4 +169,17 @@ class ResolveHintsSuite extends AnalysisTest {
UnresolvedHint("REPARTITION", Seq(Literal(true)), table("TaBlE")),
Seq(errMsgRepa))
}
+
+ test("log warnings for invalid hints") {
+ val logAppender = new MockAppender()
+ withLogAppender(logAppender) {
+ checkAnalysis(
+ UnresolvedHint("unknown_hint", Seq("TaBlE"), table("TaBlE")),
+ testRelation,
+ caseSensitive = false)
+ }
+ assert(logAppender.loggingEvents.exists(
+ e => e.getLevel == Level.WARN &&
+ e.getRenderedMessage.contains("Unrecognized hint: unknown_hint")))
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index bf38189..efd05a3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -90,61 +90,35 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   }

   /**
-   * Select the proper physical plan for join based on joining keys and size of logical plan.
-   *
-   * At first, uses the [[ExtractEquiJoinKeys]] pattern to find joins where at least some of the
-   * predicates can be evaluated by matching join keys. If found, join implementations are chosen
-   * with the following precedence:
+   * Select the proper physical plan for join based on join strategy hints, the availability of
+   * equi-join keys and the sizes of joining relations. Below are the existing join strategies,
+   * their characteristics and their limitations.
    *
    * - Broadcast hash join (BHJ):
-   *     BHJ is not supported for full outer join. For right outer join, we only can broadcast the
-   *     left side. For left outer, left semi, left anti and the internal join type ExistenceJoin,
-   *     we only can broadcast the right side. For inner like join, we can broadcast both sides.
-   *     Normally, BHJ can perform faster than the other join algorithms when the broadcast side is
-   *     small. However, broadcasting tables is a network-intensive operation. It could cause OOM
-   *     or perform worse than the other join algorithms, especially when the build/broadcast side
-   *     is big.
-   *
-   *     For the supported cases, users can specify the broadcast hint (e.g. the user applied the
-   *     [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame) and session-based
-   *     [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold to adjust whether BHJ is used and
-   *     which join side is broadcast.
-   *
-   *     1) Broadcast the join side with the broadcast hint, even if the size is larger than
-   *     [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]]. If both sides have the hint (only when the type
-   *     is inner like join), the side with a smaller estimated physical size will be broadcast.
-   *     2) Respect the [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold and broadcast the side
-   *     whose estimated physical size is smaller than the threshold. If both sides are below the
-   *     threshold, broadcast the smaller side. If neither is smaller, BHJ is not used.
-   *
-   * - Shuffle hash join: if the average size of a single partition is small enough to build a hash
-   *     table.
-   *
-   * - Sort merge: if the matching join keys are sortable.
-   *
-   * If there is no joining keys, Join implementations are chosen with the following precedence:
-   * - BroadcastNestedLoopJoin (BNLJ):
-   *     BNLJ supports all the join types but the impl is OPTIMIZED for the following scenarios:
-   *     For right outer join, the left side is broadcast. For left outer, left semi, left anti
-   *     and the internal join type ExistenceJoin, the right side is broadcast. For inner like
-   *     joins, either side is broadcast.
+   *     Only supported for equi-joins, while the join keys do not need to be sortable.
+   *     Supported for all join types except full outer joins.
+   *     BHJ usually performs faster than the other join algorithms when the broadcast side is
+   *     small. However, broadcasting tables is a network-intensive operation and it could cause
+   *     OOM or perform badly in some cases, especially when the build/broadcast side is big.
    *
-   *     Like BHJ, users still can specify the broadcast hint and session-based
-   *     [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold to impact which side is broadcast.
+   * - Shuffle hash join:
+   *     Only supported for equi-joins, while the join keys do not need to be sortable.
+   *     Supported for all join types except full outer joins.
    *
-   *     1) Broadcast the join side with the broadcast hint, even if the size is larger than
-   *     [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]]. If both sides have the hint (i.e., just for
-   *     inner-like join), the side with a smaller estimated physical size will be broadcast.
-   *     2) Respect the [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold and broadcast the side
-   *     whose estimated physical size is smaller than the threshold. If both sides are below the
-   *     threshold, broadcast the smaller side. If neither is smaller, BNLJ is not used.
+   * - Shuffle sort merge join (SMJ):
+   *     Only supported for equi-joins and the join keys have to be sortable.
+   *     Supported for all join types.
    *
-   * - CartesianProduct: for inner like join, CartesianProduct is the fallback option.
+   * - Broadcast nested loop join (BNLJ):
+   *     Supports both equi-joins and non-equi-joins.
+   *     Supports all the join types, but the implementation is optimized for:
+   *       1) broadcasting the left side in a right outer join;
+   *       2) broadcasting the right side in a left outer, left semi, left anti or existence join;
+   *       3) broadcasting either side in an inner-like join.
    *
-   * - BroadcastNestedLoopJoin (BNLJ):
-   *     For the other join types, BNLJ is the fallback option. Here, we just pick the broadcast
-   *     side with the broadcast hint. If neither side has a hint, we broadcast the side with
-   *     the smaller estimated physical size.
+   * - Shuffle-and-replicate nested loop join (a.k.a. cartesian product join):
+   *     Supports both equi-joins and non-equi-joins.
+   *     Supports only inner like joins.
    */
object JoinSelection extends Strategy with PredicateHelper {
@@ -186,126 +160,218 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case _ => false
}
-    private def broadcastSide(
-        canBuildLeft: Boolean,
-        canBuildRight: Boolean,
+    private def getBuildSide(
+        wantToBuildLeft: Boolean,
+        wantToBuildRight: Boolean,
         left: LogicalPlan,
-        right: LogicalPlan): BuildSide = {
-
-      def smallerSide =
-        if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft
-
-      if (canBuildRight && canBuildLeft) {
-        // Broadcast smaller side base on its estimated physical size
-        // if both sides have broadcast hint
-        smallerSide
-      } else if (canBuildRight) {
-        BuildRight
-      } else if (canBuildLeft) {
-        BuildLeft
+        right: LogicalPlan): Option[BuildSide] = {
+      if (wantToBuildLeft && wantToBuildRight) {
+        // Returns the smaller side based on its estimated physical size if we want to build
+        // both sides.
+        Some(getSmallerSide(left, right))
+      } else if (wantToBuildLeft) {
+        Some(BuildLeft)
+      } else if (wantToBuildRight) {
+        Some(BuildRight)
       } else {
-        // for the last default broadcast nested loop join
-        smallerSide
+        None
       }
     }
-  private def canBroadcastByHints(
-      joinType: JoinType, left: LogicalPlan, right: LogicalPlan, hint: JoinHint): Boolean = {
-    val buildLeft = canBuildLeft(joinType) && hint.leftHint.exists(_.broadcast)
-    val buildRight = canBuildRight(joinType) && hint.rightHint.exists(_.broadcast)
-    buildLeft || buildRight
+  private def getSmallerSide(left: LogicalPlan, right: LogicalPlan) = {
+    if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft
   }
-  private def broadcastSideByHints(
-      joinType: JoinType, left: LogicalPlan, right: LogicalPlan, hint: JoinHint): BuildSide = {
-    val buildLeft = canBuildLeft(joinType) && hint.leftHint.exists(_.broadcast)
-    val buildRight = canBuildRight(joinType) && hint.rightHint.exists(_.broadcast)
-    broadcastSide(buildLeft, buildRight, left, right)
+  private def hintToBroadcastLeft(hint: JoinHint): Boolean = {
+    hint.leftHint.exists(_.strategy.contains(BROADCAST))
   }
-  private def canBroadcastBySizes(joinType: JoinType, left: LogicalPlan, right: LogicalPlan)
-    : Boolean = {
-    val buildLeft = canBuildLeft(joinType) && canBroadcast(left)
-    val buildRight = canBuildRight(joinType) && canBroadcast(right)
-    buildLeft || buildRight
+  private def hintToBroadcastRight(hint: JoinHint): Boolean = {
+    hint.rightHint.exists(_.strategy.contains(BROADCAST))
   }
-  private def broadcastSideBySizes(joinType: JoinType, left: LogicalPlan, right: LogicalPlan)
-    : BuildSide = {
-    val buildLeft = canBuildLeft(joinType) && canBroadcast(left)
-    val buildRight = canBuildRight(joinType) && canBroadcast(right)
-    broadcastSide(buildLeft, buildRight, left, right)
+  private def hintToShuffleHashLeft(hint: JoinHint): Boolean = {
+    hint.leftHint.exists(_.strategy.contains(SHUFFLE_HASH))
+  }
+
+ private def hintToShuffleHashRight(hint: JoinHint): Boolean = {
+ hint.rightHint.exists(_.strategy.contains(SHUFFLE_HASH))
+ }
+
+ private def hintToSortMergeJoin(hint: JoinHint): Boolean = {
+ hint.leftHint.exists(_.strategy.contains(SHUFFLE_MERGE)) ||
+ hint.rightHint.exists(_.strategy.contains(SHUFFLE_MERGE))
+ }
+
+ private def hintToShuffleReplicateNL(hint: JoinHint): Boolean = {
+ hint.leftHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL)) ||
+ hint.rightHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL))
}
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- // --- BroadcastHashJoin
--------------------------------------------------------------------
-
- // broadcast hints were specified
- case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left,
right, hint)
- if canBroadcastByHints(joinType, left, right, hint) =>
- val buildSide = broadcastSideByHints(joinType, left, right, hint)
- Seq(joins.BroadcastHashJoinExec(
- leftKeys, rightKeys, joinType, buildSide, condition,
planLater(left), planLater(right)))
-
- // broadcast hints were not specified, so need to infer it from size and
configuration.
- case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left,
right, _)
- if canBroadcastBySizes(joinType, left, right) =>
- val buildSide = broadcastSideBySizes(joinType, left, right)
- Seq(joins.BroadcastHashJoinExec(
- leftKeys, rightKeys, joinType, buildSide, condition,
planLater(left), planLater(right)))
-
- // --- ShuffledHashJoin
---------------------------------------------------------------------
-
- case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left,
right, _)
- if !conf.preferSortMergeJoin && canBuildRight(joinType) &&
canBuildLocalHashMap(right)
- && muchSmaller(right, left) ||
- !RowOrdering.isOrderable(leftKeys) =>
- Seq(joins.ShuffledHashJoinExec(
- leftKeys, rightKeys, joinType, BuildRight, condition,
planLater(left), planLater(right)))
-
- case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left,
right, _)
- if !conf.preferSortMergeJoin && canBuildLeft(joinType) &&
canBuildLocalHashMap(left)
- && muchSmaller(left, right) ||
- !RowOrdering.isOrderable(leftKeys) =>
- Seq(joins.ShuffledHashJoinExec(
- leftKeys, rightKeys, joinType, BuildLeft, condition,
planLater(left), planLater(right)))
-
- // --- SortMergeJoin
------------------------------------------------------------
-
- case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left,
right, _)
- if RowOrdering.isOrderable(leftKeys) =>
- joins.SortMergeJoinExec(
- leftKeys, rightKeys, joinType, condition, planLater(left),
planLater(right)) :: Nil
-
- // --- Without joining keys
------------------------------------------------------------
-
- // Pick BroadcastNestedLoopJoin if one side could be broadcast
- case j @ logical.Join(left, right, joinType, condition, hint)
- if canBroadcastByHints(joinType, left, right, hint) =>
- val buildSide = broadcastSideByHints(joinType, left, right, hint)
- joins.BroadcastNestedLoopJoinExec(
- planLater(left), planLater(right), buildSide, joinType, condition)
:: Nil
-
- case j @ logical.Join(left, right, joinType, condition, _)
- if canBroadcastBySizes(joinType, left, right) =>
- val buildSide = broadcastSideBySizes(joinType, left, right)
- joins.BroadcastNestedLoopJoinExec(
- planLater(left), planLater(right), buildSide, joinType, condition)
:: Nil
-
- // Pick CartesianProduct for InnerJoin
- case logical.Join(left, right, _: InnerLike, condition, _) =>
- joins.CartesianProductExec(planLater(left), planLater(right),
condition) :: Nil
+    // If it is an equi-join, we first look at the join hints w.r.t. the following order:
+    //   1. broadcast hint: pick broadcast hash join if the join type is supported. If both
+    //      sides have the broadcast hints, choose the smaller side (based on stats)
+    //      to broadcast.
+    //   2. sort merge hint: pick sort merge join if join keys are sortable.
+    //   3. shuffle hash hint: pick shuffle hash join if the join type is supported. If both
+    //      sides have the shuffle hash hints, choose the smaller side (based on stats) as the
+    //      build side.
+    //   4. shuffle replicate NL hint: pick cartesian product if join type is inner like.
+    //
+    // If there is no hint or the hints are not applicable, we follow these rules one by one:
+    //   1. Pick broadcast hash join if one side is small enough to broadcast, and the join
+    //      type is supported. If both sides are small, choose the smaller side (based on
+    //      stats) to broadcast.
+    //   2. Pick shuffle hash join if one side is small enough to build a local hash map, and
+    //      is much smaller than the other side, and `spark.sql.join.preferSortMergeJoin` is
+    //      false.
+    //   3. Pick sort merge join if the join keys are sortable.
+    //   4. Pick cartesian product if join type is inner like.
+    //   5. Pick broadcast nested loop join as the final solution. It may OOM but there is no
+    //      other choice.
+    case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, hint) =>
+      def createBroadcastHashJoin(buildLeft: Boolean, buildRight: Boolean) = {
+        val wantToBuildLeft = canBuildLeft(joinType) && buildLeft
+        val wantToBuildRight = canBuildRight(joinType) && buildRight
+        getBuildSide(wantToBuildLeft, wantToBuildRight, left, right).map { buildSide =>
+          Seq(joins.BroadcastHashJoinExec(
+            leftKeys,
+            rightKeys,
+            joinType,
+            buildSide,
+            condition,
+            planLater(left),
+            planLater(right)))
+        }
+      }
+
+ def createShuffleHashJoin(buildLeft: Boolean, buildRight: Boolean) = {
+ val wantToBuildLeft = canBuildLeft(joinType) && buildLeft
+ val wantToBuildRight = canBuildRight(joinType) && buildRight
+ getBuildSide(wantToBuildLeft, wantToBuildRight, left, right).map {
buildSide =>
+ Seq(joins.ShuffledHashJoinExec(
+ leftKeys,
+ rightKeys,
+ joinType,
+ buildSide,
+ condition,
+ planLater(left),
+ planLater(right)))
+ }
+ }
+
+ def createSortMergeJoin() = {
+ if (RowOrdering.isOrderable(leftKeys)) {
+ Some(Seq(joins.SortMergeJoinExec(
+ leftKeys, rightKeys, joinType, condition, planLater(left),
planLater(right))))
+ } else {
+ None
+ }
+ }
+
+ def createCartesianProduct() = {
+ if (joinType.isInstanceOf[InnerLike]) {
+ Some(Seq(joins.CartesianProductExec(planLater(left),
planLater(right), condition)))
+ } else {
+ None
+ }
+ }
+
+ def createJoinWithoutHint() = {
+ createBroadcastHashJoin(canBroadcast(left), canBroadcast(right))
+ .orElse {
+ if (!conf.preferSortMergeJoin) {
+ createShuffleHashJoin(
+ canBuildLocalHashMap(left) && muchSmaller(left, right),
+ canBuildLocalHashMap(right) && muchSmaller(right, left))
+ } else {
+ None
+ }
+ }
+ .orElse(createSortMergeJoin())
+ .orElse(createCartesianProduct())
+ .getOrElse {
+ // This join could be very slow or OOM
+ val buildSide = getSmallerSide(left, right)
+ Seq(joins.BroadcastNestedLoopJoinExec(
+ planLater(left), planLater(right), buildSide, joinType,
condition))
+ }
+ }
+      createBroadcastHashJoin(hintToBroadcastLeft(hint), hintToBroadcastRight(hint))
+        .orElse { if (hintToSortMergeJoin(hint)) createSortMergeJoin() else None }
+        .orElse(createShuffleHashJoin(hintToShuffleHashLeft(hint), hintToShuffleHashRight(hint)))
+        .orElse { if (hintToShuffleReplicateNL(hint)) createCartesianProduct() else None }
+        .getOrElse(createJoinWithoutHint())
+
+    // If it is not an equi-join, we first look at the join hints w.r.t. the following order:
+    //   1. broadcast hint: pick broadcast nested loop join. If both sides have the broadcast
+    //      hints, choose the smaller side (based on stats) to broadcast.
+    //   2. shuffle replicate NL hint: pick cartesian product if join type is inner like.
+    //
+    // If there is no hint or the hints are not applicable, we follow these rules one by one:
+    //   1. Pick cartesian product if join type is inner like, and both sides are too big
+    //      to broadcast.
+    //   2. Pick broadcast nested loop join. Pick the smaller side (based on stats)
+    //      to broadcast.
case logical.Join(left, right, joinType, condition, hint) =>
- val buildSide = broadcastSide(
- hint.leftHint.exists(_.broadcast),
hint.rightHint.exists(_.broadcast), left, right)
- // This join could be very slow or OOM
- joins.BroadcastNestedLoopJoinExec(
- planLater(left), planLater(right), buildSide, joinType, condition)
:: Nil
+ def createBroadcastNLJoin(buildLeft: Boolean, buildRight: Boolean) = {
+ getBuildSide(buildLeft, buildRight, left, right).map { buildSide =>
+ Seq(joins.BroadcastNestedLoopJoinExec(
+ planLater(left), planLater(right), buildSide, joinType,
condition))
+ }
+ }
- // --- Cases where this strategy does not apply
---------------------------------------------
+ def createCartesianProduct() = {
+ if (joinType.isInstanceOf[InnerLike]) {
+ Some(Seq(joins.CartesianProductExec(planLater(left),
planLater(right), condition)))
+ } else {
+ None
+ }
+ }
+
+ def createJoinWithoutHint() = {
+ (if (!canBroadcast(left) && !canBroadcast(right))
createCartesianProduct() else None)
+ .getOrElse {
+ // This join could be very slow or OOM
+ val buildSide = getSmallerSide(left, right)
+ Seq(joins.BroadcastNestedLoopJoinExec(
+ planLater(left), planLater(right), buildSide, joinType,
condition))
+ }
+ }
+
+ if (joinType.isInstanceOf[InnerLike] || joinType == FullOuter) {
+ createBroadcastNLJoin(hintToBroadcastLeft(hint),
hintToBroadcastRight(hint))
+ .orElse { if (hintToShuffleReplicateNL(hint))
createCartesianProduct() else None }
+ .getOrElse(createJoinWithoutHint())
+ } else {
+ val smallerSide = getSmallerSide(left, right)
+ val buildSide = if (canBuildLeft(joinType)) {
+          // For RIGHT JOIN, we may broadcast the left side even if the hint asks us to
+          // broadcast the right side. This is for historical reasons.
+ if (hintToBroadcastLeft(hint) || canBroadcast(left)) {
+ BuildLeft
+ } else if (hintToBroadcastRight(hint)) {
+ BuildRight
+ } else {
+ smallerSide
+ }
+ } else {
+          // For LEFT JOIN, we may broadcast the right side even if the hint asks us to
+          // broadcast the left side. This is for historical reasons.
+ if (hintToBroadcastRight(hint) || canBroadcast(right)) {
+ BuildRight
+ } else if (hintToBroadcastLeft(hint)) {
+ BuildLeft
+ } else {
+ smallerSide
+ }
+ }
+ Seq(joins.BroadcastNestedLoopJoinExec(
+ planLater(left), planLater(right), buildSide, joinType, condition))
+ }
+
+ // --- Cases where this strategy does not apply
---------------------------------------------
case _ => Nil
}
}
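The new `getBuildSide`/`getSmallerSide` helpers in this hunk can be tried outside Spark. The sketch below is a minimal, Spark-free stand-in: `BuildSide`, `Plan`, and the size figures are hypothetical simplifications, while the method names and branch logic mirror the patch.

```scala
// Simplified stand-ins for Catalyst's BuildSide and a plan with size stats.
sealed trait BuildSide
case object BuildLeft extends BuildSide
case object BuildRight extends BuildSide

final case class Plan(sizeInBytes: BigInt)

// Mirrors getSmallerSide: prefer building the side with the smaller estimate.
def getSmallerSide(left: Plan, right: Plan): BuildSide =
  if (right.sizeInBytes <= left.sizeInBytes) BuildRight else BuildLeft

// Mirrors getBuildSide: None means "this strategy is not applicable", which is
// what lets callers chain candidate strategies with Option.orElse.
def getBuildSide(
    wantToBuildLeft: Boolean,
    wantToBuildRight: Boolean,
    left: Plan,
    right: Plan): Option[BuildSide] = {
  if (wantToBuildLeft && wantToBuildRight) {
    // Both sides requested: build the smaller one based on estimated size.
    Some(getSmallerSide(left, right))
  } else if (wantToBuildLeft) {
    Some(BuildLeft)
  } else if (wantToBuildRight) {
    Some(BuildRight)
  } else {
    None
  }
}

val small = Plan(sizeInBytes = BigInt(1024))
val large = Plan(sizeInBytes = BigInt(1024 * 1024))
assert(getBuildSide(wantToBuildLeft = true, wantToBuildRight = true, small, large)
  .contains(BuildLeft))
assert(getBuildSide(wantToBuildLeft = false, wantToBuildRight = false, small, large).isEmpty)
```

Returning `Option` is what makes each `create*Join` helper in the patch composable: hint priority reduces to a chain of `orElse` calls, with `createJoinWithoutHint()` as the final `getOrElse`.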
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index bcb5783..7ac3ed5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, ResolvedHint}
+import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, ResolvedHint}
import org.apache.spark.sql.execution.SparkSqlParser
import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedFunction}
import org.apache.spark.sql.internal.SQLConf
@@ -1045,7 +1045,7 @@ object functions {
*/
def broadcast[T](df: Dataset[T]): Dataset[T] = {
Dataset[T](df.sparkSession,
- ResolvedHint(df.logicalPlan, HintInfo(broadcast = true)))(df.exprEnc)
+      ResolvedHint(df.logicalPlan, HintInfo(strategy = Some(BROADCAST))))(df.exprEnc)
}
/**
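The `functions.broadcast` change above swaps `HintInfo(broadcast = true)` for an `Option`-valued `strategy` field. A tiny stand-in (these types are simplified sketches, not the real Catalyst classes) shows why call sites throughout the tests switch from a Boolean check to `Option.contains`:

```scala
// Simplified stand-ins for the hint model introduced by this commit.
sealed trait JoinStrategyHint
case object BROADCAST extends JoinStrategyHint
case object SHUFFLE_MERGE extends JoinStrategyHint

final case class HintInfo(strategy: Option[JoinStrategyHint] = None)
final case class JoinHint(leftHint: Option[HintInfo], rightHint: Option[HintInfo])

// Old style: hint.leftHint.exists(_.broadcast)
// New style: a strategy must both be present and be BROADCAST.
def hintToBroadcastLeft(hint: JoinHint): Boolean =
  hint.leftHint.exists(_.strategy.contains(BROADCAST))

val broadcastLeft = JoinHint(Some(HintInfo(strategy = Some(BROADCAST))), None)
val mergeLeft = JoinHint(Some(HintInfo(strategy = Some(SHUFFLE_MERGE))), None)
assert(hintToBroadcastLeft(broadcastLeft))
assert(!hintToBroadcastLeft(mergeLeft))
assert(!hintToBroadcastLeft(JoinHint(None, None)))
```

One `Option[JoinStrategyHint]` field generalizes cleanly to the four strategies, where a set of independent Booleans would have allowed contradictory states.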
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 4d63390..92157d8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.executor.DataReadMethod.DataReadMethod
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
-import org.apache.spark.sql.catalyst.plans.logical.Join
+import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, Join}
import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan}
import org.apache.spark.sql.execution.columnar._
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
@@ -951,7 +951,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
case Join(_, _, _, _, hint) => hint
}
assert(hint.size == 1)
- assert(hint(0).leftHint.get.broadcast)
+ assert(hint(0).leftHint.get.strategy.contains(BROADCAST))
assert(hint(0).rightHint.isEmpty)
// Clean-up
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
index 67f0f1a..9c2dc0c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
@@ -17,8 +17,14 @@
package org.apache.spark.sql
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.log4j.{AppenderSkeleton, Level}
+import org.apache.log4j.spi.LoggingEvent
+
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
@@ -30,6 +36,41 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
lazy val df2 = df.selectExpr("id as b1", "id as b2")
lazy val df3 = df.selectExpr("id as c1", "id as c2")
+ class MockAppender extends AppenderSkeleton {
+ val loggingEvents = new ArrayBuffer[LoggingEvent]()
+
+ override def append(loggingEvent: LoggingEvent): Unit =
loggingEvents.append(loggingEvent)
+ override def close(): Unit = {}
+ override def requiresLayout(): Boolean = false
+ }
+
+  def msgNoHintRelationFound(relation: String, hint: String): String =
+    s"Could not find relation '$relation' for join strategy hint '$hint'."
+
+ def msgNoJoinForJoinHint(strategy: String): String =
+ s"A join hint (strategy=$strategy) is specified but it is not part of a
join relation."
+
+ def msgJoinHintOverridden(strategy: String): String =
+ s"Join hint (strategy=$strategy) is overridden by another hint and will
not take effect."
+
+ def verifyJoinHintWithWarnings(
+ df: => DataFrame,
+ expectedHints: Seq[JoinHint],
+ warnings: Seq[String]): Unit = {
+ val logAppender = new MockAppender()
+ withLogAppender(logAppender) {
+ verifyJoinHint(df, expectedHints)
+ }
+ val warningMessages = logAppender.loggingEvents
+ .filter(_.getLevel == Level.WARN)
+ .map(_.getRenderedMessage)
+ .filter(_.contains("hint"))
+ assert(warningMessages.size == warnings.size)
+ warnings.foreach { w =>
+ assert(warningMessages.contains(w))
+ }
+ }
+
def verifyJoinHint(df: DataFrame, expectedHints: Seq[JoinHint]): Unit = {
val optimized = df.queryExecution.optimizedPlan
val joinHints = optimized collect {
@@ -43,14 +84,14 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
verifyJoinHint(
df.hint("broadcast").join(df, "id"),
JoinHint(
- Some(HintInfo(broadcast = true)),
+ Some(HintInfo(strategy = Some(BROADCAST))),
None) :: Nil
)
verifyJoinHint(
df.join(df.hint("broadcast"), "id"),
JoinHint(
None,
- Some(HintInfo(broadcast = true))) :: Nil
+ Some(HintInfo(strategy = Some(BROADCAST)))) :: Nil
)
}
@@ -59,18 +100,18 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
df1.join(df2.hint("broadcast").join(df3, 'b1 === 'c1).hint("broadcast"),
'a1 === 'c1),
JoinHint(
None,
- Some(HintInfo(broadcast = true))) ::
+ Some(HintInfo(strategy = Some(BROADCAST)))) ::
JoinHint(
- Some(HintInfo(broadcast = true)),
+ Some(HintInfo(strategy = Some(BROADCAST))),
None) :: Nil
)
verifyJoinHint(
df1.hint("broadcast").join(df2, 'a1 === 'b1).hint("broadcast").join(df3,
'a1 === 'c1),
JoinHint(
- Some(HintInfo(broadcast = true)),
+ Some(HintInfo(strategy = Some(BROADCAST))),
None) ::
JoinHint(
- Some(HintInfo(broadcast = true)),
+ Some(HintInfo(strategy = Some(BROADCAST))),
None) :: Nil
)
}
@@ -89,13 +130,13 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
|) b on a.a1 = b.b1
""".stripMargin),
JoinHint(
- Some(HintInfo(broadcast = true)),
- Some(HintInfo(broadcast = true))) ::
+ Some(HintInfo(strategy = Some(BROADCAST))),
+ Some(HintInfo(strategy = Some(BROADCAST)))) ::
JoinHint(
None,
- Some(HintInfo(broadcast = true))) ::
+ Some(HintInfo(strategy = Some(BROADCAST)))) ::
JoinHint(
- Some(HintInfo(broadcast = true)),
+ Some(HintInfo(strategy = Some(BROADCAST))),
None) :: Nil
)
}
@@ -112,9 +153,9 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
"where a.a1 = b.b1 and b.b1 = c.c1"),
JoinHint(
None,
- Some(HintInfo(broadcast = true))) ::
+ Some(HintInfo(strategy = Some(BROADCAST)))) ::
JoinHint(
- Some(HintInfo(broadcast = true)),
+ Some(HintInfo(strategy = Some(BROADCAST))),
None) :: Nil
)
verifyJoinHint(
@@ -122,25 +163,25 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
"where a.a1 = b.b1 and b.b1 = c.c1"),
JoinHint.NONE ::
JoinHint(
- Some(HintInfo(broadcast = true)),
- Some(HintInfo(broadcast = true))) :: Nil
+ Some(HintInfo(strategy = Some(BROADCAST))),
+ Some(HintInfo(strategy = Some(BROADCAST)))) :: Nil
)
verifyJoinHint(
sql("select /*+ broadcast(b, c)*/ * from a, c, b " +
"where a.a1 = b.b1 and b.b1 = c.c1"),
JoinHint(
None,
- Some(HintInfo(broadcast = true))) ::
+ Some(HintInfo(strategy = Some(BROADCAST)))) ::
JoinHint(
None,
- Some(HintInfo(broadcast = true))) :: Nil
+ Some(HintInfo(strategy = Some(BROADCAST)))) :: Nil
)
verifyJoinHint(
df1.join(df2, 'a1 === 'b1 && 'a1 > 5).hint("broadcast")
.join(df3, 'b1 === 'c1 && 'a1 < 10),
JoinHint(
- Some(HintInfo(broadcast = true)),
+ Some(HintInfo(strategy = Some(BROADCAST))),
None) ::
JoinHint.NONE :: Nil
)
@@ -151,7 +192,7 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
.join(df, 'b1 === 'id),
JoinHint.NONE ::
JoinHint(
- Some(HintInfo(broadcast = true)),
+ Some(HintInfo(strategy = Some(BROADCAST))),
None) ::
JoinHint.NONE :: Nil
)
@@ -164,7 +205,7 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
verifyJoinHint(
df.hint("broadcast").except(dfSub).join(df, "id"),
JoinHint(
- Some(HintInfo(broadcast = true)),
+ Some(HintInfo(strategy = Some(BROADCAST))),
None) ::
JoinHint.NONE :: Nil
)
@@ -172,31 +213,112 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
df.join(df.hint("broadcast").intersect(dfSub), "id"),
JoinHint(
None,
- Some(HintInfo(broadcast = true))) ::
+ Some(HintInfo(strategy = Some(BROADCAST)))) ::
JoinHint.NONE :: Nil
)
}
test("hint merge") {
- verifyJoinHint(
+ verifyJoinHintWithWarnings(
df.hint("broadcast").filter('id > 2).hint("broadcast").join(df, "id"),
JoinHint(
- Some(HintInfo(broadcast = true)),
- None) :: Nil
+ Some(HintInfo(strategy = Some(BROADCAST))),
+ None) :: Nil,
+ Nil
)
- verifyJoinHint(
+ verifyJoinHintWithWarnings(
df.join(df.hint("broadcast").limit(2).hint("broadcast"), "id"),
JoinHint(
None,
- Some(HintInfo(broadcast = true))) :: Nil
+ Some(HintInfo(strategy = Some(BROADCAST)))) :: Nil,
+ Nil
+ )
+ verifyJoinHintWithWarnings(
+ df.hint("merge").filter('id > 2).hint("shuffle_hash").join(df,
"id").hint("broadcast"),
+ JoinHint(
+ Some(HintInfo(strategy = Some(SHUFFLE_HASH))),
+ None) :: Nil,
+ msgJoinHintOverridden("merge") ::
+ msgNoJoinForJoinHint("broadcast") :: Nil
+ )
+ verifyJoinHintWithWarnings(
+ df.join(df.hint("broadcast").limit(2).hint("merge"), "id")
+ .hint("shuffle_hash")
+ .hint("shuffle_replicate_nl")
+ .join(df, "id"),
+ JoinHint(
+ Some(HintInfo(strategy = Some(SHUFFLE_REPLICATE_NL))),
+ None) ::
+ JoinHint(
+ None,
+ Some(HintInfo(strategy = Some(SHUFFLE_MERGE)))) :: Nil,
+ msgJoinHintOverridden("broadcast") ::
+ msgJoinHintOverridden("shuffle_hash") :: Nil
)
}
+ test("hint merge - SQL") {
+ withTempView("a", "b", "c") {
+ df1.createOrReplaceTempView("a")
+ df2.createOrReplaceTempView("b")
+ df3.createOrReplaceTempView("c")
+ verifyJoinHintWithWarnings(
+ sql("select /*+ shuffle_hash merge(a, c) broadcast(a, b)*/ * from a,
b, c " +
+ "where a.a1 = b.b1 and b.b1 = c.c1"),
+ JoinHint(
+ None,
+ Some(HintInfo(strategy = Some(SHUFFLE_MERGE)))) ::
+ JoinHint(
+ Some(HintInfo(strategy = Some(SHUFFLE_MERGE))),
+ Some(HintInfo(strategy = Some(BROADCAST)))) :: Nil,
+ msgNoJoinForJoinHint("shuffle_hash") ::
+ msgJoinHintOverridden("broadcast") :: Nil
+ )
+ verifyJoinHintWithWarnings(
+ sql("select /*+ shuffle_hash(a, b) merge(b, d) broadcast(b)*/ * from
a, b, c " +
+ "where a.a1 = b.b1 and b.b1 = c.c1"),
+ JoinHint.NONE ::
+ JoinHint(
+ Some(HintInfo(strategy = Some(SHUFFLE_HASH))),
+ Some(HintInfo(strategy = Some(SHUFFLE_HASH)))) :: Nil,
+ msgNoHintRelationFound("d", "merge(b, d)") ::
+ msgJoinHintOverridden("broadcast") ::
+ msgJoinHintOverridden("merge") :: Nil
+ )
+ verifyJoinHintWithWarnings(
+ sql(
+ """
+ |select /*+ broadcast(a, c) merge(a, d)*/ * from a
+ |join (
+ | select /*+ shuffle_hash(c) shuffle_replicate_nl(b, c)*/ * from b
+ | join c on b.b1 = c.c1
+ |) as d
+ |on a.a2 = d.b2
+ """.stripMargin),
+ JoinHint(
+ Some(HintInfo(strategy = Some(BROADCAST))),
+ Some(HintInfo(strategy = Some(SHUFFLE_MERGE)))) ::
+ JoinHint(
+ Some(HintInfo(strategy = Some(SHUFFLE_REPLICATE_NL))),
+ Some(HintInfo(strategy = Some(SHUFFLE_HASH)))) :: Nil,
+ msgNoHintRelationFound("c", "broadcast(a, c)") ::
+ msgJoinHintOverridden("merge") ::
+ msgJoinHintOverridden("shuffle_replicate_nl") :: Nil
+ )
+ }
+ }
+
test("nested hint") {
verifyJoinHint(
df.hint("broadcast").hint("broadcast").filter('id > 2).join(df, "id"),
JoinHint(
- Some(HintInfo(broadcast = true)),
+ Some(HintInfo(strategy = Some(BROADCAST))),
+ None) :: Nil
+ )
+ verifyJoinHint(
+ df.hint("shuffle_hash").hint("broadcast").hint("merge").filter('id >
2).join(df, "id"),
+ JoinHint(
+ Some(HintInfo(strategy = Some(SHUFFLE_MERGE))),
None) :: Nil
)
}
@@ -209,12 +331,230 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
join.join(broadcasted, "id").join(broadcasted, "id"),
JoinHint(
None,
- Some(HintInfo(broadcast = true))) ::
+ Some(HintInfo(strategy = Some(BROADCAST)))) ::
JoinHint(
None,
- Some(HintInfo(broadcast = true))) ::
+ Some(HintInfo(strategy = Some(BROADCAST)))) ::
JoinHint.NONE :: JoinHint.NONE :: JoinHint.NONE :: Nil
)
}
}
+
+ def equiJoinQueryWithHint(hints: Seq[String], joinType: String = "INNER"):
String =
+ hints.map("/*+ " + _ + " */").mkString(
+ "SELECT ", " ", s" * FROM t1 $joinType JOIN t2 ON t1.key = t2.key")
+
+ def nonEquiJoinQueryWithHint(hints: Seq[String], joinType: String =
"INNER"): String =
+ hints.map("/*+ " + _ + " */").mkString(
+ "SELECT ", " ", s" * FROM t1 $joinType JOIN t2 ON t1.key > t2.key")
+
+ private def assertBroadcastHashJoin(df: DataFrame, buildSide: BuildSide):
Unit = {
+ val executedPlan = df.queryExecution.executedPlan
+ val broadcastHashJoins = executedPlan.collect {
+ case b: BroadcastHashJoinExec => b
+ }
+ assert(broadcastHashJoins.size == 1)
+ assert(broadcastHashJoins.head.buildSide == buildSide)
+ }
+
+ private def assertBroadcastNLJoin(df: DataFrame, buildSide: BuildSide): Unit
= {
+ val executedPlan = df.queryExecution.executedPlan
+ val broadcastNLJoins = executedPlan.collect {
+ case b: BroadcastNestedLoopJoinExec => b
+ }
+ assert(broadcastNLJoins.size == 1)
+ assert(broadcastNLJoins.head.buildSide == buildSide)
+ }
+
+ private def assertShuffleHashJoin(df: DataFrame, buildSide: BuildSide): Unit
= {
+ val executedPlan = df.queryExecution.executedPlan
+ val shuffleHashJoins = executedPlan.collect {
+ case s: ShuffledHashJoinExec => s
+ }
+ assert(shuffleHashJoins.size == 1)
+ assert(shuffleHashJoins.head.buildSide == buildSide)
+ }
+
+ private def assertShuffleMergeJoin(df: DataFrame): Unit = {
+ val executedPlan = df.queryExecution.executedPlan
+ val shuffleMergeJoins = executedPlan.collect {
+ case s: SortMergeJoinExec => s
+ }
+ assert(shuffleMergeJoins.size == 1)
+ }
+
+ private def assertShuffleReplicateNLJoin(df: DataFrame): Unit = {
+ val executedPlan = df.queryExecution.executedPlan
+ val shuffleReplicateNLJoins = executedPlan.collect {
+ case c: CartesianProductExec => c
+ }
+ assert(shuffleReplicateNLJoins.size == 1)
+ }
+
+ test("join strategy hint - broadcast") {
+ withTempView("t1", "t2") {
+ Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1")
+ Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key",
"value").createTempView("t2")
+
+ val t1Size =
spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes
+ val t2Size =
spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes
+ assert(t1Size < t2Size)
+
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+ // Broadcast hint specified on one side
+ assertBroadcastHashJoin(
+ sql(equiJoinQueryWithHint("BROADCAST(t1)" :: Nil)), BuildLeft)
+ assertBroadcastNLJoin(
+ sql(nonEquiJoinQueryWithHint("BROADCAST(t2)" :: Nil)), BuildRight)
+
+ // Determine build side based on the join type and child relation sizes
+ assertBroadcastHashJoin(
+ sql(equiJoinQueryWithHint("BROADCAST(t1, t2)" :: Nil)), BuildLeft)
+ assertBroadcastNLJoin(
+ sql(nonEquiJoinQueryWithHint("BROADCAST(t1, t2)" :: Nil, "left")),
BuildRight)
+ assertBroadcastNLJoin(
+ sql(nonEquiJoinQueryWithHint("BROADCAST(t1, t2)" :: Nil, "right")),
BuildLeft)
+
+ // Use broadcast-hash join if hinted "broadcast" and equi-join
+ assertBroadcastHashJoin(
+ sql(equiJoinQueryWithHint("BROADCAST(t2)" :: "SHUFFLE_HASH(t1)" ::
Nil)), BuildRight)
+ assertBroadcastHashJoin(
+ sql(equiJoinQueryWithHint("BROADCAST(t1)" :: "MERGE(t1, t2)" ::
Nil)), BuildLeft)
+ assertBroadcastHashJoin(
+ sql(equiJoinQueryWithHint("BROADCAST(t1)" ::
"SHUFFLE_REPLICATE_NL(t2)" :: Nil)),
+ BuildLeft)
+
+ // Use broadcast-nl join if hinted "broadcast" and non-equi-join
+ assertBroadcastNLJoin(
+ sql(nonEquiJoinQueryWithHint("SHUFFLE_HASH(t2)" :: "BROADCAST(t1)"
:: Nil)), BuildLeft)
+ assertBroadcastNLJoin(
+ sql(nonEquiJoinQueryWithHint("MERGE(t1)" :: "BROADCAST(t2)" ::
Nil)), BuildRight)
+ assertBroadcastNLJoin(
+ sql(nonEquiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t1)" ::
"BROADCAST(t2)" :: Nil)),
+ BuildRight)
+
+ // Broadcast hint specified but not doable
+ assertShuffleMergeJoin(
+ sql(equiJoinQueryWithHint("BROADCAST(t1)" :: Nil, "left")))
+ assertShuffleMergeJoin(
+ sql(equiJoinQueryWithHint("BROADCAST(t2)" :: Nil, "right")))
+ }
+ }
+ }
+
+ test("join strategy hint - shuffle-merge") {
+ withTempView("t1", "t2") {
+ Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1")
+ Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key",
"value").createTempView("t2")
+
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key ->
Int.MaxValue.toString) {
+ // Shuffle-merge hint specified on one side
+ assertShuffleMergeJoin(
+ sql(equiJoinQueryWithHint("SHUFFLE_MERGE(t1)" :: Nil)))
+ assertShuffleMergeJoin(
+ sql(equiJoinQueryWithHint("MERGEJOIN(t2)" :: Nil)))
+
+ // Shuffle-merge hint specified on both sides
+ assertShuffleMergeJoin(
+ sql(equiJoinQueryWithHint("MERGE(t1, t2)" :: Nil)))
+
+ // Shuffle-merge hint prioritized over shuffle-hash hint and
shuffle-replicate-nl hint
+ assertShuffleMergeJoin(
+ sql(equiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t2)" :: "MERGE(t1)"
:: Nil, "left")))
+ assertShuffleMergeJoin(
+ sql(equiJoinQueryWithHint("MERGE(t2)" :: "SHUFFLE_HASH(t1)" :: Nil,
"right")))
+
+ // Broadcast hint prioritized over shuffle-merge hint, but broadcast
hint is not applicable
+ assertShuffleMergeJoin(
+ sql(equiJoinQueryWithHint("BROADCAST(t1)" :: "MERGE(t2)" :: Nil,
"left")))
+ assertShuffleMergeJoin(
+ sql(equiJoinQueryWithHint("BROADCAST(t2)" :: "MERGE(t1)" :: Nil,
"right")))
+
+ // Shuffle-merge hint specified but not doable
+ assertBroadcastNLJoin(
+ sql(nonEquiJoinQueryWithHint("MERGE(t1, t2)" :: Nil, "left")),
BuildRight)
+ }
+ }
+ }
+
+ test("join strategy hint - shuffle-hash") {
+ withTempView("t1", "t2") {
+ Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1")
+ Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key",
"value").createTempView("t2")
+
+ val t1Size =
spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes
+ val t2Size =
spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes
+ assert(t1Size < t2Size)
+
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key ->
Int.MaxValue.toString) {
+ // Shuffle-hash hint specified on one side
+ assertShuffleHashJoin(
+ sql(equiJoinQueryWithHint("SHUFFLE_HASH(t1)" :: Nil)), BuildLeft)
+ assertShuffleHashJoin(
+ sql(equiJoinQueryWithHint("SHUFFLE_HASH(t2)" :: Nil)), BuildRight)
+
+ // Determine build side based on the join type and child relation sizes
+ assertShuffleHashJoin(
+ sql(equiJoinQueryWithHint("SHUFFLE_HASH(t1, t2)" :: Nil)), BuildLeft)
+ assertShuffleHashJoin(
+ sql(equiJoinQueryWithHint("SHUFFLE_HASH(t1, t2)" :: Nil, "left")),
BuildRight)
+ assertShuffleHashJoin(
+ sql(equiJoinQueryWithHint("SHUFFLE_HASH(t1, t2)" :: Nil, "right")),
BuildLeft)
+
+ // Shuffle-hash hint prioritized over shuffle-replicate-nl hint
+ assertShuffleHashJoin(
+ sql(equiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t2)" ::
"SHUFFLE_HASH(t1)" :: Nil)),
+ BuildLeft)
+
+ // Broadcast hint prioritized over shuffle-hash hint, but broadcast
hint is not applicable
+ assertShuffleHashJoin(
+ sql(equiJoinQueryWithHint("BROADCAST(t1)" :: "SHUFFLE_HASH(t2)" ::
Nil, "left")),
+ BuildRight)
+ assertShuffleHashJoin(
+ sql(equiJoinQueryWithHint("BROADCAST(t2)" :: "SHUFFLE_HASH(t1)" ::
Nil, "right")),
+ BuildLeft)
+
+ // Shuffle-hash hint specified but not doable
+ assertBroadcastHashJoin(
+ sql(equiJoinQueryWithHint("SHUFFLE_HASH(t1)" :: Nil, "left")),
BuildRight)
+ assertBroadcastNLJoin(
+ sql(nonEquiJoinQueryWithHint("SHUFFLE_HASH(t1)" :: Nil)), BuildLeft)
+ }
+ }
+ }
+
+ test("join strategy hint - shuffle-replicate-nl") {
+ withTempView("t1", "t2") {
+ Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1")
+ Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key",
"value").createTempView("t2")
+
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key ->
Int.MaxValue.toString) {
+ // Shuffle-replicate-nl hint specified on one side
+ assertShuffleReplicateNLJoin(
+ sql(equiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t1)" :: Nil)))
+ assertShuffleReplicateNLJoin(
+ sql(equiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t2)" :: Nil)))
+
+ // Shuffle-replicate-nl hint specified on both sides
+ assertShuffleReplicateNLJoin(
+ sql(equiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t1, t2)" :: Nil)))
+
+ // Shuffle-merge hint prioritized over shuffle-replicate-nl hint, but shuffle-merge hint
+ // is not applicable
+ assertShuffleReplicateNLJoin(
+ sql(nonEquiJoinQueryWithHint("MERGE(t1)" :: "SHUFFLE_REPLICATE_NL(t2)" :: Nil)))
+
+ // Shuffle-hash hint prioritized over shuffle-replicate-nl hint, but shuffle-hash hint
+ // is not applicable
+ assertShuffleReplicateNLJoin(
+ sql(nonEquiJoinQueryWithHint("SHUFFLE_HASH(t2)" :: "SHUFFLE_REPLICATE_NL(t1)" :: Nil)))
+
+ // Shuffle-replicate-nl hint specified but not doable
+ assertBroadcastHashJoin(
+ sql(equiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t1, t2)" :: Nil, "left")), BuildRight)
+ assertBroadcastNLJoin(
+ sql(nonEquiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t1, t2)" :: Nil, "right")), BuildLeft)
+ }
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index f238148..05c583c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -22,6 +22,7 @@ import scala.reflect.ClassTag
import org.apache.spark.AccumulatorSuite
import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft}
+import org.apache.spark.sql.catalyst.plans.logical.BROADCAST
import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.exchange.EnsureRequirements
@@ -216,10 +217,10 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
val plan3 = sql(s"SELECT /*+ $name(v) */ * FROM t JOIN u ON t.id = u.id").queryExecution
.optimizedPlan
- assert(plan1.asInstanceOf[Join].hint.leftHint.get.broadcast)
+ assert(plan1.asInstanceOf[Join].hint.leftHint.get.strategy.contains(BROADCAST))
assert(plan1.asInstanceOf[Join].hint.rightHint.isEmpty)
assert(plan2.asInstanceOf[Join].hint.leftHint.isEmpty)
- assert(plan2.asInstanceOf[Join].hint.rightHint.get.broadcast)
+ assert(plan2.asInstanceOf[Join].hint.rightHint.get.strategy.contains(BROADCAST))
assert(plan3.asInstanceOf[Join].hint.leftHint.isEmpty)
assert(plan3.asInstanceOf[Join].hint.rightHint.isEmpty)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]