asfgit closed pull request #23036: [SPARK-26065][SQL] Change query hint from a
`LogicalPlan` to a field
URL: https://github.com/apache/spark/pull/23036
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 198645d875c47..2aa0f2117364c 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -943,7 +943,7 @@ class Analyzer(
failAnalysis("Invalid usage of '*' in explode/json_tuple/UDTF")
// To resolve duplicate expression IDs for Join and Intersect
- case j @ Join(left, right, _, _) if !j.duplicateResolved =>
+ case j @ Join(left, right, _, _, _) if !j.duplicateResolved =>
j.copy(right = dedupRight(left, right))
case i @ Intersect(left, right, _) if !i.duplicateResolved =>
i.copy(right = dedupRight(left, right))
@@ -2249,13 +2249,14 @@ class Analyzer(
*/
object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan =
plan.resolveOperatorsUp {
- case j @ Join(left, right, UsingJoin(joinType, usingCols), _)
+ case j @ Join(left, right, UsingJoin(joinType, usingCols), _, hint)
if left.resolved && right.resolved && j.duplicateResolved =>
- commonNaturalJoinProcessing(left, right, joinType, usingCols, None)
- case j @ Join(left, right, NaturalJoin(joinType), condition) if
j.resolvedExceptNatural =>
+ commonNaturalJoinProcessing(left, right, joinType, usingCols, None,
hint)
+ case j @ Join(left, right, NaturalJoin(joinType), condition, hint)
+ if j.resolvedExceptNatural =>
// find common column names from both sides
val joinNames =
left.output.map(_.name).intersect(right.output.map(_.name))
- commonNaturalJoinProcessing(left, right, joinType, joinNames,
condition)
+ commonNaturalJoinProcessing(left, right, joinType, joinNames,
condition, hint)
}
}
@@ -2360,7 +2361,8 @@ class Analyzer(
right: LogicalPlan,
joinType: JoinType,
joinNames: Seq[String],
- condition: Option[Expression]) = {
+ condition: Option[Expression],
+ hint: JoinHint) = {
val leftKeys = joinNames.map { keyName =>
left.output.find(attr => resolver(attr.name, keyName)).getOrElse {
throw new AnalysisException(s"USING column `$keyName` cannot be
resolved on the left " +
@@ -2401,7 +2403,7 @@ class Analyzer(
sys.error("Unsupported natural join type " + joinType)
}
// use Project to trim unnecessary fields
- Project(projectList, Join(left, right, joinType, newCondition))
+ Project(projectList, Join(left, right, joinType, newCondition, hint))
}
/**
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index c28a97839fe49..18c40b370cb5f 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -172,7 +172,7 @@ trait CheckAnalysis extends PredicateHelper {
failAnalysis("Null-aware predicate sub-queries cannot be used in
nested " +
s"conditions: $condition")
- case j @ Join(_, _, _, Some(condition)) if condition.dataType !=
BooleanType =>
+ case j @ Join(_, _, _, Some(condition), _) if condition.dataType !=
BooleanType =>
failAnalysis(
s"join condition '${condition.sql}' " +
s"of type ${condition.dataType.catalogString} is not a
boolean.")
@@ -609,7 +609,7 @@ trait CheckAnalysis extends PredicateHelper {
failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a)
// Join can host correlated expressions.
- case j @ Join(left, right, joinType, _) =>
+ case j @ Join(left, right, joinType, _, _) =>
joinType match {
// Inner join, like Filter, can be anywhere.
case _: InnerLike =>
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala
index 7a0aa08289efa..76733dd6dac3c 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala
@@ -41,7 +41,7 @@ object StreamingJoinHelper extends PredicateHelper with
Logging {
*/
def isWatermarkInJoinKeys(plan: LogicalPlan): Boolean = {
plan match {
- case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _, _) =>
+ case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _, _, _) =>
(leftKeys ++ rightKeys).exists {
case a: AttributeReference =>
a.metadata.contains(EventTimeWatermark.delayKey)
case _ => false
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index cff4cee09427f..41ba6d34b5499 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -229,7 +229,7 @@ object UnsupportedOperationChecker {
throwError("dropDuplicates is not supported after aggregation on a "
+
"streaming DataFrame/Dataset")
- case Join(left, right, joinType, condition) =>
+ case Join(left, right, joinType, condition, _) =>
joinType match {
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 151481c80ee96..846ee3b386527 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -325,7 +325,7 @@ package object dsl {
otherPlan: LogicalPlan,
joinType: JoinType = Inner,
condition: Option[Expression] = None): LogicalPlan =
- Join(logicalPlan, otherPlan, joinType, condition)
+ Join(logicalPlan, otherPlan, joinType, condition, JoinHint.NONE)
def cogroup[Key: Encoder, Left: Encoder, Right: Encoder, Result:
Encoder](
otherPlan: LogicalPlan,
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
index 01634a9d852c6..743d3ce944fe2 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.{And, Attribute,
AttributeSet, Expression, PredicateHelper}
import org.apache.spark.sql.catalyst.plans.{Inner, InnerLike, JoinType}
-import org.apache.spark.sql.catalyst.plans.logical.{BinaryNode, Join,
LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
@@ -31,6 +31,40 @@ import org.apache.spark.sql.internal.SQLConf
* Cost-based join reorder.
* We may have several join reorder algorithms in the future. This class is
the entry of these
* algorithms, and chooses which one to use.
+ *
+ * Note that join strategy hints, e.g. the broadcast hint, do not interfere
with the reordering.
+ * Such hints will be applied on the equivalent counterparts (i.e., join
between the same relations
+ * regardless of the join order) of the original nodes after reordering.
+ * For example, the plan before reordering is like:
+ *
+ * Join
+ * / \
+ * Hint1 t4
+ * /
+ * Join
+ * / \
+ * Join t3
+ * / \
+ * Hint2 t2
+ * /
+ * t1
+ *
+ * The original join order as illustrated above is "((t1 JOIN t2) JOIN t3)
JOIN t4", and after
+ * reordering, the new join order is "((t1 JOIN t3) JOIN t2) JOIN t4", so the
new plan will be like:
+ *
+ * Join
+ * / \
+ * Hint1 t4
+ * /
+ * Join
+ * / \
+ * Join t2
+ * / \
+ * t1 t3
+ *
+ * "Hint1" is applied on "(t1 JOIN t3) JOIN t2" as it is equivalent to the
original hinted node,
+ * "(t1 JOIN t2) JOIN t3"; while "Hint2" has disappeared from the new plan
since there is no
+ * equivalent node to "t1 JOIN t2".
*/
object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
@@ -40,24 +74,30 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with
PredicateHelper {
if (!conf.cboEnabled || !conf.joinReorderEnabled) {
plan
} else {
+ // Use a map to track the hints on the join items.
+ val hintMap = new mutable.HashMap[AttributeSet, HintInfo]
val result = plan transformDown {
// Start reordering with a joinable item, which is an InnerLike join
with conditions.
- case j @ Join(_, _, _: InnerLike, Some(cond)) =>
- reorder(j, j.output)
- case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond)))
+ case j @ Join(_, _, _: InnerLike, Some(cond), _) =>
+ reorder(j, j.output, hintMap)
+ case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond), _))
if projectList.forall(_.isInstanceOf[Attribute]) =>
- reorder(p, p.output)
+ reorder(p, p.output, hintMap)
}
-
- // After reordering is finished, convert OrderedJoin back to Join
- result transformDown {
- case OrderedJoin(left, right, jt, cond) => Join(left, right, jt, cond)
+ // After reordering is finished, convert OrderedJoin back to Join.
+ result transform {
+ case OrderedJoin(left, right, jt, cond) =>
+ val joinHint = JoinHint(hintMap.get(left.outputSet),
hintMap.get(right.outputSet))
+ Join(left, right, jt, cond, joinHint)
}
}
}
- private def reorder(plan: LogicalPlan, output: Seq[Attribute]): LogicalPlan
= {
- val (items, conditions) = extractInnerJoins(plan)
+ private def reorder(
+ plan: LogicalPlan,
+ output: Seq[Attribute],
+ hintMap: mutable.HashMap[AttributeSet, HintInfo]): LogicalPlan = {
+ val (items, conditions) = extractInnerJoins(plan, hintMap)
val result =
// Do reordering if the number of items is appropriate and join
conditions exist.
// We also need to check if costs of all items can be evaluated.
@@ -75,27 +115,31 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with
PredicateHelper {
* Extracts items of consecutive inner joins and join conditions.
* This method works for bushy trees and left/right deep trees.
*/
- private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan],
Set[Expression]) = {
+ private def extractInnerJoins(
+ plan: LogicalPlan,
+ hintMap: mutable.HashMap[AttributeSet, HintInfo]): (Seq[LogicalPlan],
Set[Expression]) = {
plan match {
- case Join(left, right, _: InnerLike, Some(cond)) =>
- val (leftPlans, leftConditions) = extractInnerJoins(left)
- val (rightPlans, rightConditions) = extractInnerJoins(right)
+ case Join(left, right, _: InnerLike, Some(cond), hint) =>
+ hint.leftHint.foreach(hintMap.put(left.outputSet, _))
+ hint.rightHint.foreach(hintMap.put(right.outputSet, _))
+ val (leftPlans, leftConditions) = extractInnerJoins(left, hintMap)
+ val (rightPlans, rightConditions) = extractInnerJoins(right, hintMap)
(leftPlans ++ rightPlans, splitConjunctivePredicates(cond).toSet ++
leftConditions ++ rightConditions)
- case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond)))
+ case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), _))
if projectList.forall(_.isInstanceOf[Attribute]) =>
- extractInnerJoins(j)
+ extractInnerJoins(j, hintMap)
case _ =>
(Seq(plan), Set())
}
}
private def replaceWithOrderedJoin(plan: LogicalPlan): LogicalPlan = plan
match {
- case j @ Join(left, right, jt: InnerLike, Some(cond)) =>
+ case j @ Join(left, right, jt: InnerLike, Some(cond), _) =>
val replacedLeft = replaceWithOrderedJoin(left)
val replacedRight = replaceWithOrderedJoin(right)
OrderedJoin(replacedLeft, replacedRight, jt, Some(cond))
- case p @ Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond))) =>
+ case p @ Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), _))
=>
p.copy(child = replaceWithOrderedJoin(j))
case _ =>
plan
@@ -295,7 +339,7 @@ object JoinReorderDP extends PredicateHelper with Logging {
} else {
(otherPlan, onePlan)
}
- val newJoin = Join(left, right, Inner, joinConds.reduceOption(And))
+ val newJoin = Join(left, right, Inner, joinConds.reduceOption(And),
JoinHint.NONE)
val collectedJoinConds = joinConds ++ oneJoinPlan.joinConds ++
otherJoinPlan.joinConds
val remainingConds = conditions -- collectedJoinConds
val neededAttr = AttributeSet(remainingConds.flatMap(_.references)) ++
topOutput
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala
new file mode 100644
index 0000000000000..bbe4eee4b4326
--- /dev/null
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * Replaces [[ResolvedHint]] operators from the plan. Move the [[HintInfo]] to
associated [[Join]]
+ * operators, otherwise remove it if no [[Join]] operator is matched.
+ */
+object EliminateResolvedHint extends Rule[LogicalPlan] {
+ // This is also called at the beginning of the optimization phase, and as a
result
+ // is using transformUp rather than resolveOperators.
+ def apply(plan: LogicalPlan): LogicalPlan = {
+ val pulledUp = plan transformUp {
+ case j: Join =>
+ val leftHint = mergeHints(collectHints(j.left))
+ val rightHint = mergeHints(collectHints(j.right))
+ j.copy(hint = JoinHint(leftHint, rightHint))
+ }
+ pulledUp.transform {
+ case h: ResolvedHint => h.child
+ }
+ }
+
+ private def mergeHints(hints: Seq[HintInfo]): Option[HintInfo] = {
+ hints.reduceOption((h1, h2) => HintInfo(
+ broadcast = h1.broadcast || h2.broadcast))
+ }
+
+ private def collectHints(plan: LogicalPlan): Seq[HintInfo] = {
+ plan match {
+ case h: ResolvedHint => collectHints(h.child) :+ h.hints
+ case u: UnaryNode => collectHints(u.child)
+ // TODO revisit this logic:
+ // except and intersect are semi/anti-joins which won't return more data
than
+ // their left argument, so the broadcast hint should be propagated here
+ case i: Intersect => collectHints(i.left)
+ case e: Except => collectHints(e.left)
+ case _ => Seq.empty
+ }
+ }
+}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 44d5543114902..06f908281dd3c 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -115,6 +115,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
// However, because we also use the analyzer to canonicalized queries (for
view definition),
// we do not eliminate subqueries or compute current time in the analyzer.
Batch("Finish Analysis", Once,
+ EliminateResolvedHint,
EliminateSubqueryAliases,
EliminateView,
ReplaceExpressions,
@@ -192,6 +193,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
*/
def nonExcludableRules: Seq[String] =
EliminateDistinct.ruleName ::
+ EliminateResolvedHint.ruleName ::
EliminateSubqueryAliases.ruleName ::
EliminateView.ruleName ::
ReplaceExpressions.ruleName ::
@@ -356,7 +358,7 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] {
// not allowed to use the same attributes. We use a blacklist to prevent
us from creating a
// situation in which this happens; the rule will only remove an alias
if its child
// attribute is not on the black list.
- case Join(left, right, joinType, condition) =>
+ case Join(left, right, joinType, condition, hint) =>
val newLeft = removeRedundantAliases(left, blacklist ++
right.outputSet)
val newRight = removeRedundantAliases(right, blacklist ++
newLeft.outputSet)
val mapping = AttributeMap(
@@ -365,7 +367,7 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] {
val newCondition = condition.map(_.transform {
case a: Attribute => mapping.getOrElse(a, a)
})
- Join(newLeft, newRight, joinType, newCondition)
+ Join(newLeft, newRight, joinType, newCondition, hint)
case _ =>
// Remove redundant aliases in the subtree(s).
@@ -460,7 +462,7 @@ object LimitPushDown extends Rule[LogicalPlan] {
// on both sides if it is applied multiple times. Therefore:
// - If one side is already limited, stack another limit on top if the
new limit is smaller.
// The redundant limit will be collapsed by the CombineLimits rule.
- case LocalLimit(exp, join @ Join(left, right, joinType, _)) =>
+ case LocalLimit(exp, join @ Join(left, right, joinType, _, _)) =>
val newJoin = joinType match {
case RightOuter => join.copy(right = maybePushLocalLimit(exp, right))
case LeftOuter => join.copy(left = maybePushLocalLimit(exp, left))
@@ -578,7 +580,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
p.copy(child = g.copy(child = newChild, unrequiredChildIndex =
unrequiredIndices))
// Eliminate unneeded attributes from right side of a Left Existence Join.
- case j @ Join(_, right, LeftExistence(_), _) =>
+ case j @ Join(_, right, LeftExistence(_), _, _) =>
j.copy(right = prunedChild(right, j.references))
// all the columns will be used to compare, so we can't prune them
@@ -792,7 +794,7 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
filter
}
- case join @ Join(left, right, joinType, conditionOpt) =>
+ case join @ Join(left, right, joinType, conditionOpt, _) =>
joinType match {
// For inner join, we can infer additional filters for both sides.
LeftSemi is kind of an
// inner join, it just drops the right side in the final output.
@@ -919,7 +921,6 @@ object RemoveRedundantSorts extends Rule[LogicalPlan] {
def canEliminateSort(plan: LogicalPlan): Boolean = plan match {
case p: Project => p.projectList.forall(_.deterministic)
case f: Filter => f.condition.deterministic
- case _: ResolvedHint => true
case _ => false
}
}
@@ -1094,7 +1095,6 @@ object PushDownPredicate extends Rule[LogicalPlan] with
PredicateHelper {
// Note that some operators (e.g. project, aggregate, union) are being
handled separately
// (earlier in this rule).
case _: AppendColumns => true
- case _: ResolvedHint => true
case _: Distinct => true
case _: Generate => true
case _: Pivot => true
@@ -1179,7 +1179,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan]
with PredicateHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// push the where condition down into join filter
- case f @ Filter(filterCondition, Join(left, right, joinType,
joinCondition)) =>
+ case f @ Filter(filterCondition, Join(left, right, joinType,
joinCondition, hint)) =>
val (leftFilterConditions, rightFilterConditions, commonFilterCondition)
=
split(splitConjunctivePredicates(filterCondition), left, right)
joinType match {
@@ -1193,7 +1193,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan]
with PredicateHelper {
commonFilterCondition.partition(canEvaluateWithinJoin)
val newJoinCond = (newJoinConditions ++
joinCondition).reduceLeftOption(And)
- val join = Join(newLeft, newRight, joinType, newJoinCond)
+ val join = Join(newLeft, newRight, joinType, newJoinCond, hint)
if (others.nonEmpty) {
Filter(others.reduceLeft(And), join)
} else {
@@ -1205,7 +1205,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan]
with PredicateHelper {
val newRight = rightFilterConditions.
reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
val newJoinCond = joinCondition
- val newJoin = Join(newLeft, newRight, RightOuter, newJoinCond)
+ val newJoin = Join(newLeft, newRight, RightOuter, newJoinCond, hint)
(leftFilterConditions ++ commonFilterCondition).
reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
@@ -1215,7 +1215,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan]
with PredicateHelper {
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
val newRight = right
val newJoinCond = joinCondition
- val newJoin = Join(newLeft, newRight, joinType, newJoinCond)
+ val newJoin = Join(newLeft, newRight, joinType, newJoinCond, hint)
(rightFilterConditions ++ commonFilterCondition).
reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
@@ -1225,7 +1225,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan]
with PredicateHelper {
}
// push down the join filter into sub query scanning if applicable
- case j @ Join(left, right, joinType, joinCondition) =>
+ case j @ Join(left, right, joinType, joinCondition, hint) =>
val (leftJoinConditions, rightJoinConditions, commonJoinCondition) =
split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil),
left, right)
@@ -1238,7 +1238,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan]
with PredicateHelper {
reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
val newJoinCond = commonJoinCondition.reduceLeftOption(And)
- Join(newLeft, newRight, joinType, newJoinCond)
+ Join(newLeft, newRight, joinType, newJoinCond, hint)
case RightOuter =>
// push down the left side only join filter for left side sub query
val newLeft = leftJoinConditions.
@@ -1246,7 +1246,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan]
with PredicateHelper {
val newRight = right
val newJoinCond = (rightJoinConditions ++
commonJoinCondition).reduceLeftOption(And)
- Join(newLeft, newRight, RightOuter, newJoinCond)
+ Join(newLeft, newRight, RightOuter, newJoinCond, hint)
case LeftOuter | LeftAnti | ExistenceJoin(_) =>
// push down the right side only join filter for right sub query
val newLeft = left
@@ -1254,7 +1254,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan]
with PredicateHelper {
reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
val newJoinCond = (leftJoinConditions ++
commonJoinCondition).reduceLeftOption(And)
- Join(newLeft, newRight, joinType, newJoinCond)
+ Join(newLeft, newRight, joinType, newJoinCond, hint)
case FullOuter => j
case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
case UsingJoin(_, _) => sys.error("Untransformed Using join node")
@@ -1310,7 +1310,7 @@ object CheckCartesianProducts extends Rule[LogicalPlan]
with PredicateHelper {
if (SQLConf.get.crossJoinEnabled) {
plan
} else plan transform {
- case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, _)
+ case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter,
_, _)
if isCartesianProduct(j) =>
throw new AnalysisException(
s"""Detected implicit cartesian product for ${j.joinType.sql} join
between logical plans
@@ -1449,7 +1449,7 @@ object ReplaceIntersectWithSemiJoin extends
Rule[LogicalPlan] {
case Intersect(left, right, false) =>
assert(left.output.size == right.output.size)
val joinCond = left.output.zip(right.output).map { case (l, r) =>
EqualNullSafe(l, r) }
- Distinct(Join(left, right, LeftSemi, joinCond.reduceLeftOption(And)))
+ Distinct(Join(left, right, LeftSemi, joinCond.reduceLeftOption(And),
JoinHint.NONE))
}
}
@@ -1470,7 +1470,7 @@ object ReplaceExceptWithAntiJoin extends
Rule[LogicalPlan] {
case Except(left, right, false) =>
assert(left.output.size == right.output.size)
val joinCond = left.output.zip(right.output).map { case (l, r) =>
EqualNullSafe(l, r) }
- Distinct(Join(left, right, LeftAnti, joinCond.reduceLeftOption(And)))
+ Distinct(Join(left, right, LeftAnti, joinCond.reduceLeftOption(And),
JoinHint.NONE))
}
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
index c3fdb924243df..b19e13870aa65 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
@@ -56,7 +56,7 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with
PredicateHelper wit
// Joins on empty LocalRelations generated from streaming sources are not
eliminated
// as stateful streaming joins need to perform other state management
operations other than
// just processing the input data.
- case p @ Join(_, _, joinType, _)
+ case p @ Join(_, _, joinType, _, _)
if !p.children.exists(_.isStreaming) =>
val isLeftEmpty = isEmptyLocalRelation(p.left)
val isRightEmpty = isEmptyLocalRelation(p.right)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala
index 72a60f692ac78..689915a985343 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala
@@ -52,7 +52,7 @@ object ReplaceNullWithFalseInPredicate extends
Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond))
- case j @ Join(_, _, _, Some(cond)) => j.copy(condition =
Some(replaceNullWithFalse(cond)))
+ case j @ Join(_, _, _, Some(cond), _) => j.copy(condition =
Some(replaceNullWithFalse(cond)))
case p: LogicalPlan => p transformExpressions {
case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred))
case cw @ CaseWhen(branches, _) =>
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 468a950fb1087..39709529c00d3 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -600,7 +600,7 @@ object FoldablePropagation extends Rule[LogicalPlan] {
// propagating the foldable expressions.
// TODO(cloud-fan): It seems more reasonable to use new attributes as
the output attributes
// of outer join.
- case j @ Join(left, right, joinType, _) if foldableMap.nonEmpty =>
+ case j @ Join(left, right, joinType, _, _) if foldableMap.nonEmpty =>
val newJoin = j.transformExpressions(replaceFoldable)
val missDerivedAttrsSet: AttributeSet = AttributeSet(joinType match {
case _: InnerLike | LeftExistence(_) => Nil
@@ -648,7 +648,6 @@ object FoldablePropagation extends Rule[LogicalPlan] {
case _: Distinct => true
case _: AppendColumns => true
case _: AppendColumnsWithObject => true
- case _: ResolvedHint => true
case _: RepartitionByExpression => true
case _: Repartition => true
case _: Sort => true
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index 0b6471289a471..82aefca8a1af6 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -43,10 +43,13 @@ object ReorderJoin extends Rule[LogicalPlan] with
PredicateHelper {
*
* @param input a list of LogicalPlans to inner join and the type of inner
join.
* @param conditions a list of condition for join.
+ * @param hintMap a map of relation output attribute sets to their
corresponding hints.
*/
@tailrec
- final def createOrderedJoin(input: Seq[(LogicalPlan, InnerLike)],
conditions: Seq[Expression])
- : LogicalPlan = {
+ final def createOrderedJoin(
+ input: Seq[(LogicalPlan, InnerLike)],
+ conditions: Seq[Expression],
+ hintMap: Map[AttributeSet, HintInfo]): LogicalPlan = {
assert(input.size >= 2)
if (input.size == 2) {
val (joinConditions, others) =
conditions.partition(canEvaluateWithinJoin)
@@ -55,7 +58,8 @@ object ReorderJoin extends Rule[LogicalPlan] with
PredicateHelper {
case (Inner, Inner) => Inner
case (_, _) => Cross
}
- val join = Join(left, right, innerJoinType,
joinConditions.reduceLeftOption(And))
+ val join = Join(left, right, innerJoinType,
joinConditions.reduceLeftOption(And),
+ JoinHint(hintMap.get(left.outputSet), hintMap.get(right.outputSet)))
if (others.nonEmpty) {
Filter(others.reduceLeft(And), join)
} else {
@@ -78,26 +82,27 @@ object ReorderJoin extends Rule[LogicalPlan] with
PredicateHelper {
val joinedRefs = left.outputSet ++ right.outputSet
val (joinConditions, others) = conditions.partition(
e => e.references.subsetOf(joinedRefs) && canEvaluateWithinJoin(e))
- val joined = Join(left, right, innerJoinType,
joinConditions.reduceLeftOption(And))
+ val joined = Join(left, right, innerJoinType,
joinConditions.reduceLeftOption(And),
+ JoinHint(hintMap.get(left.outputSet), hintMap.get(right.outputSet)))
// should not have reference to same logical plan
- createOrderedJoin(Seq((joined, Inner)) ++ rest.filterNot(_._1 eq right),
others)
+ createOrderedJoin(Seq((joined, Inner)) ++ rest.filterNot(_._1 eq right),
others, hintMap)
}
}
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case p @ ExtractFiltersAndInnerJoins(input, conditions)
+ case p @ ExtractFiltersAndInnerJoins(input, conditions, hintMap)
if input.size > 2 && conditions.nonEmpty =>
val reordered = if (SQLConf.get.starSchemaDetection &&
!SQLConf.get.cboEnabled) {
val starJoinPlan = StarSchemaDetection.reorderStarJoins(input,
conditions)
if (starJoinPlan.nonEmpty) {
val rest = input.filterNot(starJoinPlan.contains(_))
- createOrderedJoin(starJoinPlan ++ rest, conditions)
+ createOrderedJoin(starJoinPlan ++ rest, conditions, hintMap)
} else {
- createOrderedJoin(input, conditions)
+ createOrderedJoin(input, conditions, hintMap)
}
} else {
- createOrderedJoin(input, conditions)
+ createOrderedJoin(input, conditions, hintMap)
}
if (p.sameOutput(reordered)) {
@@ -156,7 +161,7 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with
PredicateHelper {
}
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case f @ Filter(condition, j @ Join(_, _, RightOuter | LeftOuter |
FullOuter, _)) =>
+ case f @ Filter(condition, j @ Join(_, _, RightOuter | LeftOuter |
FullOuter, _, _)) =>
val newJoinType = buildNewJoinType(f, j)
if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType
= newJoinType))
}
@@ -176,7 +181,7 @@ object PullOutPythonUDFInJoinCondition extends
Rule[LogicalPlan] with PredicateH
}
override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
- case j @ Join(_, _, joinType, Some(cond)) if hasUnevaluablePythonUDF(cond,
j) =>
+ case j @ Join(_, _, joinType, Some(cond), _) if
hasUnevaluablePythonUDF(cond, j) =>
if (!joinType.isInstanceOf[InnerLike] && joinType != LeftSemi) {
// The current strategy only support InnerLike and LeftSemi join
because for other type,
// it breaks SQL semantic if we run the join condition as a filter
after join. If we pass
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 34840c6c977a6..e78ed1c3c5d94 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -51,7 +51,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan]
with PredicateHelper {
condition: Option[Expression]): Join = {
// Deduplicate conflicting attributes if any.
val dedupSubplan = dedupSubqueryOnSelfJoin(outerPlan, subplan, None,
condition)
- Join(outerPlan, dedupSubplan, joinType, condition)
+ Join(outerPlan, dedupSubplan, joinType, condition, JoinHint.NONE)
}
private def dedupSubqueryOnSelfJoin(
@@ -116,7 +116,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan]
with PredicateHelper {
val newSub = dedupSubqueryOnSelfJoin(p, sub, Some(values))
val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++
conditions, p)
- Join(outerPlan, newSub, LeftSemi, joinCond)
+ Join(outerPlan, newSub, LeftSemi, joinCond, JoinHint.NONE)
case (p, Not(InSubquery(values, ListQuery(sub, conditions, _, _)))) =>
// This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
// Construct the condition. A NULL in one of the conditions is
regarded as a positive
@@ -142,7 +142,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan]
with PredicateHelper {
// will have the final conditions in the LEFT ANTI as
// (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) AND B.B3 >
1
val finalJoinCond = (nullAwareJoinConds ++ conditions).reduceLeft(And)
- Join(outerPlan, newSub, LeftAnti, Option(finalJoinCond))
+ Join(outerPlan, newSub, LeftAnti, Option(finalJoinCond), JoinHint.NONE)
case (p, predicate) =>
val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p)
Project(p.output, Filter(newCond.get, inputPlan))
@@ -172,7 +172,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan]
with PredicateHelper {
val newSub = dedupSubqueryOnSelfJoin(newPlan, sub, Some(values))
val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
val newConditions = (inConditions ++ conditions).reduceLeftOption(And)
- newPlan = Join(newPlan, newSub, ExistenceJoin(exists), newConditions)
+ newPlan = Join(newPlan, newSub, ExistenceJoin(exists), newConditions, JoinHint.NONE)
exists
}
}
@@ -450,7 +450,7 @@ object RewriteCorrelatedScalarSubquery extends
Rule[LogicalPlan] {
// CASE 1: Subquery guaranteed not to have the COUNT bug
Project(
currentChild.output :+ origOutput,
- Join(currentChild, query, LeftOuter, conditions.reduceOption(And)))
+ Join(currentChild, query, LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
} else {
// Subquery might have the COUNT bug. Add appropriate corrections.
val (topPart, havingNode, aggNode) = splitSubquery(query)
@@ -477,7 +477,7 @@ object RewriteCorrelatedScalarSubquery extends
Rule[LogicalPlan] {
aggValRef), origOutput.name)(exprId = origOutput.exprId),
Join(currentChild,
Project(query.output :+ alwaysTrueExpr, query),
- LeftOuter, conditions.reduceOption(And)))
+ LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
} else {
// CASE 3: Subquery with HAVING clause. Pull the HAVING clause
above the join.
@@ -507,7 +507,7 @@ object RewriteCorrelatedScalarSubquery extends
Rule[LogicalPlan] {
currentChild.output :+ caseExpr,
Join(currentChild,
Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot),
- LeftOuter, conditions.reduceOption(And)))
+ LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
}
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 8959f78b656d2..a27c6d3c3671c 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -515,7 +515,7 @@ class AstBuilder(conf: SQLConf) extends
SqlBaseBaseVisitor[AnyRef] with Logging
override def visitFromClause(ctx: FromClauseContext): LogicalPlan =
withOrigin(ctx) {
val from = ctx.relation.asScala.foldLeft(null: LogicalPlan) { (left,
relation) =>
val right = plan(relation.relationPrimary)
- val join = right.optionalMap(left)(Join(_, _, Inner, None))
+ val join = right.optionalMap(left)(Join(_, _, Inner, None, JoinHint.NONE))
withJoinRelations(join, relation)
}
if (ctx.pivotClause() != null) {
@@ -727,7 +727,7 @@ class AstBuilder(conf: SQLConf) extends
SqlBaseBaseVisitor[AnyRef] with Logging
case None =>
(baseJoinType, None)
}
- Join(left, plan(join.right), joinType, condition)
+ Join(left, plan(join.right), joinType, condition, JoinHint.NONE)
}
}
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 84be677e438a6..dfc3b2d22129d 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.planning
+import scala.collection.mutable
+
import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
@@ -98,12 +100,13 @@ object PhysicalOperation extends PredicateHelper {
* value).
*/
object ExtractEquiJoinKeys extends Logging with PredicateHelper {
- /** (joinType, leftKeys, rightKeys, condition, leftChild, rightChild) */
+ /** (joinType, leftKeys, rightKeys, condition, leftChild, rightChild, joinHint) */
type ReturnType =
- (JoinType, Seq[Expression], Seq[Expression], Option[Expression], LogicalPlan, LogicalPlan)
+ (JoinType, Seq[Expression], Seq[Expression],
+ Option[Expression], LogicalPlan, LogicalPlan, JoinHint)
def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
- case join @ Join(left, right, joinType, condition) =>
+ case join @ Join(left, right, joinType, condition, hint) =>
logDebug(s"Considering join on: $condition")
// Find equi-join predicates that can be evaluated before the join, and
thus can be used
// as join keys.
@@ -133,7 +136,7 @@ object ExtractEquiJoinKeys extends Logging with
PredicateHelper {
if (joinKeys.nonEmpty) {
val (leftKeys, rightKeys) = joinKeys.unzip
logDebug(s"leftKeys:$leftKeys | rightKeys:$rightKeys")
- Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right))
+ Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right, hint))
} else {
None
}
@@ -164,25 +167,35 @@ object ExtractFiltersAndInnerJoins extends
PredicateHelper {
* was involved in an explicit cross join. Also returns the entire list of
join conditions for
* the left-deep tree.
*/
- def flattenJoin(plan: LogicalPlan, parentJoinType: InnerLike = Inner)
+ def flattenJoin(
+ plan: LogicalPlan,
+ hintMap: mutable.HashMap[AttributeSet, HintInfo],
+ parentJoinType: InnerLike = Inner)
: (Seq[(LogicalPlan, InnerLike)], Seq[Expression]) = plan match {
- case Join(left, right, joinType: InnerLike, cond) =>
- val (plans, conditions) = flattenJoin(left, joinType)
+ case Join(left, right, joinType: InnerLike, cond, hint) =>
+ val (plans, conditions) = flattenJoin(left, hintMap, joinType)
+ hint.leftHint.map(hintMap.put(left.outputSet, _))
+ hint.rightHint.map(hintMap.put(right.outputSet, _))
(plans ++ Seq((right, joinType)), conditions ++
cond.toSeq.flatMap(splitConjunctivePredicates))
- case Filter(filterCondition, j @ Join(left, right, _: InnerLike,
joinCondition)) =>
- val (plans, conditions) = flattenJoin(j)
+ case Filter(filterCondition, j @ Join(_, _, _: InnerLike, _, _)) =>
+ val (plans, conditions) = flattenJoin(j, hintMap)
(plans, conditions ++ splitConjunctivePredicates(filterCondition))
case _ => (Seq((plan, parentJoinType)), Seq.empty)
}
- def unapply(plan: LogicalPlan): Option[(Seq[(LogicalPlan, InnerLike)],
Seq[Expression])]
+ def unapply(plan: LogicalPlan)
+ : Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression],
Map[AttributeSet, HintInfo])]
= plan match {
- case f @ Filter(filterCondition, j @ Join(_, _, joinType: InnerLike, _)) =>
- Some(flattenJoin(f))
- case j @ Join(_, _, joinType, _) =>
- Some(flattenJoin(j))
+ case f @ Filter(filterCondition, j @ Join(_, _, joinType: InnerLike, _,
_)) =>
+ val hintMap = new mutable.HashMap[AttributeSet, HintInfo]
+ val flattened = flattenJoin(f, hintMap)
+ Some((flattened._1, flattened._2, hintMap.toMap))
+ case j @ Join(_, _, joinType, _, _) =>
+ val hintMap = new mutable.HashMap[AttributeSet, HintInfo]
+ val flattened = flattenJoin(j, hintMap)
+ Some((flattened._1, flattened._2, hintMap.toMap))
case _ => None
}
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala
index 2c248d74869ce..18baced8f3d61 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala
@@ -37,7 +37,6 @@ trait LogicalPlanVisitor[T] {
case p: Project => visitProject(p)
case p: Repartition => visitRepartition(p)
case p: RepartitionByExpression => visitRepartitionByExpr(p)
- case p: ResolvedHint => visitHint(p)
case p: Sample => visitSample(p)
case p: ScriptTransformation => visitScriptTransform(p)
case p: Union => visitUnion(p)
@@ -61,8 +60,6 @@ trait LogicalPlanVisitor[T] {
def visitGlobalLimit(p: GlobalLimit): T
- def visitHint(p: ResolvedHint): T
-
def visitIntersect(p: Intersect): T
def visitJoin(p: Join): T
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
index b3a48860aa63b..5a388117a6c0a 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
@@ -52,13 +52,11 @@ import org.apache.spark.util.Utils
* defaults to the product of children's `sizeInBytes`.
* @param rowCount Estimated number of rows.
* @param attributeStats Statistics for Attributes.
- * @param hints Query hints.
*/
case class Statistics(
sizeInBytes: BigInt,
rowCount: Option[BigInt] = None,
- attributeStats: AttributeMap[ColumnStat] = AttributeMap(Nil),
- hints: HintInfo = HintInfo()) {
+ attributeStats: AttributeMap[ColumnStat] = AttributeMap(Nil)) {
override def toString: String = "Statistics(" + simpleString + ")"
@@ -70,8 +68,7 @@ case class Statistics(
s"rowCount=${BigDecimal(rowCount.get, new MathContext(3,
RoundingMode.HALF_UP)).toString()}"
} else {
""
- },
- s"hints=$hints"
+ }
).filter(_.nonEmpty).mkString(", ")
}
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index d8b3a4af4f7bf..639d68f4ecd76 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -288,7 +288,8 @@ case class Join(
left: LogicalPlan,
right: LogicalPlan,
joinType: JoinType,
- condition: Option[Expression])
+ condition: Option[Expression],
+ hint: JoinHint)
extends BinaryNode with PredicateHelper {
override def output: Seq[Attribute] = {
@@ -350,6 +351,17 @@ case class Join(
case UsingJoin(_, _) => false
case _ => resolvedExceptNatural
}
+
+ // Ignore hint for canonicalization
+ protected override def doCanonicalize(): LogicalPlan =
+ super.doCanonicalize().asInstanceOf[Join].copy(hint = JoinHint.NONE)
+
+ // Do not include an empty join hint in string description
+ protected override def stringArgs: Iterator[Any] = super.stringArgs.filter {
e =>
+ (!e.isInstanceOf[JoinHint]
+ || e.asInstanceOf[JoinHint].leftHint.isDefined
+ || e.asInstanceOf[JoinHint].rightHint.isDefined)
+ }
}
/**
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
index cbb626590d1d7..b2ba725e9d44f 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
@@ -35,6 +35,7 @@ case class UnresolvedHint(name: String, parameters: Seq[Any],
child: LogicalPlan
/**
* A resolved hint node. The analyzer should convert all [[UnresolvedHint]]
into [[ResolvedHint]].
+ * This node will be eliminated before optimization starts.
*/
case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo())
extends UnaryNode {
@@ -44,11 +45,31 @@ case class ResolvedHint(child: LogicalPlan, hints: HintInfo
= HintInfo())
override def doCanonicalize(): LogicalPlan = child.canonicalized
}
+/**
+ * Hint that is associated with a [[Join]] node, with [[HintInfo]] on its left
child and on its
+ * right child respectively.
+ */
+case class JoinHint(leftHint: Option[HintInfo], rightHint: Option[HintInfo]) {
-case class HintInfo(broadcast: Boolean = false) {
+ override def toString: String = {
+ Seq(
+ leftHint.map("leftHint=" + _),
+ rightHint.map("rightHint=" + _))
+ .filter(_.isDefined).map(_.get).mkString(", ")
+ }
+}
- /** Must be called when computing stats for a join operator to reset hints.
*/
- def resetForJoin(): HintInfo = copy(broadcast = false)
+object JoinHint {
+ val NONE = JoinHint(None, None)
+}
+
+/**
+ * The hint attributes to be applied on a specific node.
+ *
+ * @param broadcast If set to true, it indicates that the broadcast hash join
is the preferred join
+ * strategy and the node with this hint is preferred to be
the build side.
+ */
+case class HintInfo(broadcast: Boolean = false) {
override def toString: String = {
val hints = scala.collection.mutable.ArrayBuffer.empty[String]
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
index 111c594a53e52..eb56ab43ea9d5 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
@@ -56,8 +56,7 @@ object AggregateEstimation {
Some(Statistics(
sizeInBytes = getOutputSize(agg.output, outputRows, outputAttrStats),
rowCount = Some(outputRows),
- attributeStats = outputAttrStats,
- hints = childStats.hints))
+ attributeStats = outputAttrStats))
} else {
None
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala
index b6c16079d1984..b8c652dc8f12e 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala
@@ -47,8 +47,6 @@ object BasicStatsPlanVisitor extends
LogicalPlanVisitor[Statistics] {
override def visitGlobalLimit(p: GlobalLimit): Statistics = fallback(p)
- override def visitHint(p: ResolvedHint): Statistics = fallback(p)
-
override def visitIntersect(p: Intersect): Statistics = fallback(p)
override def visitJoin(p: Join): Statistics = {
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
index 2543e38a92c0a..19a0d1279cc32 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
@@ -56,7 +56,7 @@ case class JoinEstimation(join: Join) extends Logging {
case _ if !rowCountsExist(join.left, join.right) =>
None
- case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _) =>
+ case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _, _) =>
// 1. Compute join selectivity
val joinKeyPairs = extractJoinKeysWithColStats(leftKeys, rightKeys)
val (numInnerJoinedRows, keyStatsAfterJoin) =
computeCardinalityAndStats(joinKeyPairs)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala
index ee43f9126386b..da36db7ae1f5f 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala
@@ -44,7 +44,7 @@ object SizeInBytesOnlyStatsPlanVisitor extends
LogicalPlanVisitor[Statistics] {
}
// Don't propagate rowCount and attributeStats, since they are not
estimated here.
- Statistics(sizeInBytes = sizeInBytes, hints = p.child.stats.hints)
+ Statistics(sizeInBytes = sizeInBytes)
}
/**
@@ -60,8 +60,7 @@ object SizeInBytesOnlyStatsPlanVisitor extends
LogicalPlanVisitor[Statistics] {
if (p.groupingExpressions.isEmpty) {
Statistics(
sizeInBytes = EstimationUtils.getOutputSize(p.output, outputRowCount =
1),
- rowCount = Some(1),
- hints = p.child.stats.hints)
+ rowCount = Some(1))
} else {
visitUnaryNode(p)
}
@@ -87,19 +86,15 @@ object SizeInBytesOnlyStatsPlanVisitor extends
LogicalPlanVisitor[Statistics] {
// Don't propagate column stats, because we don't know the distribution
after limit
Statistics(
sizeInBytes = EstimationUtils.getOutputSize(p.output, rowCount,
childStats.attributeStats),
- rowCount = Some(rowCount),
- hints = childStats.hints)
+ rowCount = Some(rowCount))
}
- override def visitHint(p: ResolvedHint): Statistics = p.child.stats.copy(hints = p.hints)
-
override def visitIntersect(p: Intersect): Statistics = {
val leftSize = p.left.stats.sizeInBytes
val rightSize = p.right.stats.sizeInBytes
val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize
Statistics(
- sizeInBytes = sizeInBytes,
- hints = p.left.stats.hints.resetForJoin())
+ sizeInBytes = sizeInBytes)
}
override def visitJoin(p: Join): Statistics = {
@@ -108,10 +103,7 @@ object SizeInBytesOnlyStatsPlanVisitor extends
LogicalPlanVisitor[Statistics] {
// LeftSemi and LeftAnti won't ever be bigger than left
p.left.stats
case _ =>
- // Make sure we don't propagate isBroadcastable in other joins, because
- // they could explode the size.
- val stats = default(p)
- stats.copy(hints = stats.hints.resetForJoin())
+ default(p)
}
}
@@ -121,7 +113,7 @@ object SizeInBytesOnlyStatsPlanVisitor extends
LogicalPlanVisitor[Statistics] {
if (limit == 0) {
// sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be
zero
// (product of children).
- Statistics(sizeInBytes = 1, rowCount = Some(0), hints = childStats.hints)
+ Statistics(sizeInBytes = 1, rowCount = Some(0))
} else {
// The output row count of LocalLimit should be the sum of row counts
from each partition.
// However, since the number of partitions is not available here, we
just use statistics of
@@ -147,7 +139,7 @@ object SizeInBytesOnlyStatsPlanVisitor extends
LogicalPlanVisitor[Statistics] {
}
val sampleRows = p.child.stats.rowCount.map(c =>
EstimationUtils.ceil(BigDecimal(c) * ratio))
// Don't propagate column stats, because we don't know the distribution
after a sample operation
- Statistics(sizeInBytes, sampleRows, hints = p.child.stats.hints)
+ Statistics(sizeInBytes, sampleRows)
}
override def visitScriptTransform(p: ScriptTransformation): Statistics =
default(p)
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 117e96175e92a..129ce3b1105ee 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -443,7 +443,7 @@ class AnalysisErrorSuite extends AnalysisTest {
}
test("error test for self-join") {
- val join = Join(testRelation, testRelation, Cross, None)
+ val join = Join(testRelation, testRelation, Cross, None, JoinHint.NONE)
val error = intercept[AnalysisException] {
SimpleAnalyzer.checkAnalysis(join)
}
@@ -565,7 +565,8 @@ class AnalysisErrorSuite extends AnalysisTest {
LocalRelation(b),
Filter(EqualTo(UnresolvedAttribute("a"), c), LocalRelation(c)),
LeftOuter,
- Option(EqualTo(b, c)))),
+ Option(EqualTo(b, c)),
+ JoinHint.NONE)),
LocalRelation(a))
assertAnalysisError(plan1, "Accessing outer query column is not allowed
in" :: Nil)
@@ -575,7 +576,8 @@ class AnalysisErrorSuite extends AnalysisTest {
Filter(EqualTo(UnresolvedAttribute("a"), c), LocalRelation(c)),
LocalRelation(b),
RightOuter,
- Option(EqualTo(b, c)))),
+ Option(EqualTo(b, c)),
+ JoinHint.NONE)),
LocalRelation(a))
assertAnalysisError(plan2, "Accessing outer query column is not allowed
in" :: Nil)
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index da3ae72c3682a..982948483fa1c 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -397,7 +397,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
Join(
Project(Seq($"x.key"), SubqueryAlias("x", input)),
Project(Seq($"y.key"), SubqueryAlias("y", input)),
- Cross, None))
+ Cross, None, JoinHint.NONE))
assertAnalysisSuccess(query)
}
@@ -578,7 +578,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
Seq(UnresolvedAttribute("a")), pythonUdf, output, project)
val left = SubqueryAlias("temp0", flatMapGroupsInPandas)
val right = SubqueryAlias("temp1", flatMapGroupsInPandas)
- val join = Join(left, right, Inner, None)
+ val join = Join(left, right, Inner, None, JoinHint.NONE)
assertAnalysisSuccess(
Project(Seq(UnresolvedAttribute("temp0.a"),
UnresolvedAttribute("temp1.a")), join))
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
index bd66ee5355f45..563e8adf87edc 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
@@ -60,7 +60,7 @@ class ResolveHintsSuite extends AnalysisTest {
checkAnalysis(
UnresolvedHint("MAPJOIN", Seq("table", "table2"),
table("table").join(table("table2"))),
Join(ResolvedHint(testRelation, HintInfo(broadcast = true)),
- ResolvedHint(testRelation2, HintInfo(broadcast = true)), Inner, None),
+ ResolvedHint(testRelation2, HintInfo(broadcast = true)), Inner, None,
JoinHint.NONE),
caseSensitive = false)
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index 57195d5fda7c5..0cd6e092e2036 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -353,15 +353,15 @@ class ColumnPruningSuite extends PlanTest {
Project(Seq($"x.key", $"y.key"),
Join(
SubqueryAlias("x", input),
- ResolvedHint(SubqueryAlias("y", input)), Inner, None)).analyze
+ SubqueryAlias("y", input), Inner, None, JoinHint.NONE)).analyze
val optimized = Optimize.execute(query)
val expected =
Join(
Project(Seq($"x.key"), SubqueryAlias("x", input)),
- ResolvedHint(Project(Seq($"y.key"), SubqueryAlias("y", input))),
- Inner, None).analyze
+ Project(Seq($"y.key"), SubqueryAlias("y", input)),
+ Inner, None, JoinHint.NONE).analyze
comparePlans(optimized, expected)
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 82a10254d846d..cf4e9fcea2c6d 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.optimizer
-import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
@@ -822,19 +821,6 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
- test("broadcast hint") {
- val originalQuery = ResolvedHint(testRelation)
- .where('a === 2L && 'b + Rand(10).as("rnd") === 3)
-
- val optimized = Optimize.execute(originalQuery.analyze)
-
- val correctAnswer = ResolvedHint(testRelation.where('a === 2L))
- .where('b + Rand(10).as("rnd") === 3)
- .analyze
-
- comparePlans(optimized, correctAnswer)
- }
-
test("union") {
val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
index 6fe5e619d03ad..9093d7fecb0f7 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
@@ -65,7 +65,8 @@ class JoinOptimizationSuite extends PlanTest {
def testExtractCheckCross
(plan: LogicalPlan, expected: Option[(Seq[(LogicalPlan, InnerLike)],
Seq[Expression])]) {
- assert(ExtractFiltersAndInnerJoins.unapply(plan) === expected)
+ assert(
+ ExtractFiltersAndInnerJoins.unapply(plan) === expected.map(e => (e._1,
e._2, Map.empty)))
}
testExtract(x, None)
@@ -124,29 +125,4 @@ class JoinOptimizationSuite extends PlanTest {
comparePlans(optimized, queryAnswerPair._2.analyze)
}
}
-
- test("broadcasthint sets relation statistics to smallest value") {
- val input = LocalRelation('key.int, 'value.string)
-
- val query =
- Project(Seq($"x.key", $"y.key"),
- Join(
- SubqueryAlias("x", input),
- ResolvedHint(SubqueryAlias("y", input)), Cross, None)).analyze
-
- val optimized = Optimize.execute(query)
-
- val expected =
- Join(
- Project(Seq($"x.key"), SubqueryAlias("x", input)),
- ResolvedHint(Project(Seq($"y.key"), SubqueryAlias("y", input))),
- Cross, None).analyze
-
- comparePlans(optimized, expected)
-
- val broadcastChildren = optimized.collect {
- case Join(_, r, _, _) if r.stats.sizeInBytes == 1 => r
- }
- assert(broadcastChildren.size == 1)
- }
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
index c94a8b9e318f6..0dee846205868 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
@@ -31,6 +31,8 @@ class JoinReorderSuite extends PlanTest with
StatsEstimationTestBase {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
+ Batch("Resolve Hints", Once,
+ EliminateResolvedHint) ::
Batch("Operator Optimizations", FixedPoint(100),
CombineFilters,
PushDownPredicate,
@@ -42,6 +44,12 @@ class JoinReorderSuite extends PlanTest with
StatsEstimationTestBase {
CostBasedJoinReorder) :: Nil
}
+ object ResolveHints extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("Resolve Hints", Once,
+ EliminateResolvedHint) :: Nil
+ }
+
var originalConfCBOEnabled = false
var originalConfJoinReorderEnabled = false
@@ -284,12 +292,85 @@ class JoinReorderSuite extends PlanTest with
StatsEstimationTestBase {
assertEqualPlans(originalPlan, bestPlan)
}
+ test("hints preservation") {
+ // Apply hints if we find an equivalent node in the new plan, otherwise
discard them.
+ val originalPlan =
+
t1.join(t2.hint("broadcast")).hint("broadcast").join(t4.join(t3).hint("broadcast"))
+ .where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
+ (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
+ (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
+
+ val bestPlan =
+ t1.join(t2.hint("broadcast"), Inner, Some(nameToAttr("t1.k-1-2") ===
nameToAttr("t2.k-1-5")))
+ .hint("broadcast")
+ .join(
+ t4.join(t3, Inner, Some(nameToAttr("t4.v-1-10") ===
nameToAttr("t3.v-1-100")))
+ .hint("broadcast"),
+ Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
+
+ assertEqualPlans(originalPlan, bestPlan)
+
+ val originalPlan2 =
+
t1.join(t2).hint("broadcast").join(t3).hint("broadcast").join(t4.hint("broadcast"))
+ .where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
+ (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
+ (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
+
+ val bestPlan2 =
+ t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") ===
nameToAttr("t2.k-1-5")))
+ .hint("broadcast")
+ .join(
+ t4.hint("broadcast")
+ .join(t3, Inner, Some(nameToAttr("t4.v-1-10") ===
nameToAttr("t3.v-1-100"))),
+ Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
+ .select(outputsOf(t1, t2, t3, t4): _*)
+
+ assertEqualPlans(originalPlan2, bestPlan2)
+
+ val originalPlan3 =
+ t1.join(t4).hint("broadcast")
+ .join(t2.hint("broadcast")).hint("broadcast")
+ .join(t3.hint("broadcast"))
+ .where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
+ (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
+ (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
+
+ val bestPlan3 =
+ t1.join(t2.hint("broadcast"), Inner, Some(nameToAttr("t1.k-1-2") ===
nameToAttr("t2.k-1-5")))
+ .join(
+ t4.join(t3.hint("broadcast"),
+ Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))),
+ Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
+ .select(outputsOf(t1, t4, t2, t3): _*)
+
+ assertEqualPlans(originalPlan3, bestPlan3)
+
+ val originalPlan4 =
+ t2.hint("broadcast")
+ .join(t4).hint("broadcast")
+ .join(t3.hint("broadcast")).hint("broadcast")
+ .join(t1)
+ .where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
+ (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
+ (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
+
+ val bestPlan4 =
+ t1.join(t2.hint("broadcast"), Inner, Some(nameToAttr("t1.k-1-2") ===
nameToAttr("t2.k-1-5")))
+ .join(
+ t4.join(t3.hint("broadcast"),
+ Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))),
+ Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
+ .select(outputsOf(t2, t4, t3, t1): _*)
+
+ assertEqualPlans(originalPlan4, bestPlan4)
+ }
+
private def assertEqualPlans(
originalPlan: LogicalPlan,
groundTruthBestPlan: LogicalPlan): Unit = {
val analyzed = originalPlan.analyze
val optimized = Optimize.execute(analyzed)
- val expected = groundTruthBestPlan.analyze
+ val expected = ResolveHints.execute(groundTruthBestPlan.analyze)
assert(analyzed.sameOutput(expected)) // if this fails, the expected plan
itself is incorrect
assert(analyzed.sameOutput(optimized))
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
index c8e15c7da763e..6d1af12e68b23 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
@@ -48,7 +48,7 @@ class ReplaceOperatorSuite extends PlanTest {
val correctAnswer =
Aggregate(table1.output, table1.output,
- Join(table1, table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd))).analyze
+ Join(table1, table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd),
JoinHint.NONE)).analyze
comparePlans(optimized, correctAnswer)
}
@@ -160,7 +160,7 @@ class ReplaceOperatorSuite extends PlanTest {
val correctAnswer =
Aggregate(table1.output, table1.output,
- Join(table1, table2, LeftAnti, Option('a <=> 'c && 'b <=> 'd))).analyze
+ Join(table1, table2, LeftAnti, Option('a <=> 'c && 'b <=> 'd),
JoinHint.NONE)).analyze
comparePlans(optimized, correctAnswer)
}
@@ -175,7 +175,7 @@ class ReplaceOperatorSuite extends PlanTest {
val correctAnswer =
Aggregate(left.output, right.output,
- Join(left, right, LeftAnti, Option($"left.a" <=> $"right.a"))).analyze
+ Join(left, right, LeftAnti, Option($"left.a" <=> $"right.a"),
JoinHint.NONE)).analyze
comparePlans(optimized, correctAnswer)
}
@@ -248,7 +248,7 @@ class ReplaceOperatorSuite extends PlanTest {
val condition = basePlan.output.zip(otherPlan.output).map { case (a1, a2)
=>
a1 <=> a2 }.reduce( _ && _)
val correctAnswer = Aggregate(basePlan.output, otherPlan.output,
- Join(basePlan, otherPlan, LeftAnti, Option(condition))).analyze
+ Join(basePlan, otherPlan, LeftAnti, Option(condition),
JoinHint.NONE)).analyze
comparePlans(result, correctAnswer)
}
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 3081ff935f043..5394732f41f2d 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -99,11 +99,11 @@ trait PlanTestBase extends PredicateHelper with SQLHelper {
self: Suite =>
.reduce(And), child)
case sample: Sample =>
sample.copy(seed = 0L)
- case Join(left, right, joinType, condition) if condition.isDefined =>
+ case Join(left, right, joinType, condition, hint) if condition.isDefined
=>
val newCondition =
splitConjunctivePredicates(condition.get).map(rewriteEqual).sortBy(_.hashCode())
.reduce(And)
- Join(left, right, joinType, Some(newCondition))
+ Join(left, right, joinType, Some(newCondition), hint)
}
}
@@ -165,8 +165,10 @@ trait PlanTestBase extends PredicateHelper with SQLHelper
{ self: Suite =>
private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = {
(plan1, plan2) match {
case (j1: Join, j2: Join) =>
- (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) ||
- (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left))
+ (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)
+ && j1.hint.leftHint == j2.hint.leftHint && j1.hint.rightHint ==
j2.hint.rightHint) ||
+ (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left)
+ && j1.hint.leftHint == j2.hint.rightHint && j1.hint.rightHint ==
j2.hint.leftHint)
case (p1: Project, p2: Project) =>
p1.projectList == p2.projectList && sameJoinPlan(p1.child, p2.child)
case _ =>
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala
index 7c8ed78a49116..fbaaf807af5d6 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala
@@ -20,7 +20,9 @@ package org.apache.spark.sql.catalyst.plans
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation,
LogicalPlan, ResolvedHint, Union}
+import org.apache.spark.sql.catalyst.optimizer.EliminateResolvedHint
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.util._
/**
@@ -30,6 +32,10 @@ class SameResultSuite extends SparkFunSuite {
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int)
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches = Batch("EliminateResolvedHint", Once, EliminateResolvedHint)
:: Nil
+ }
+
def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean =
true): Unit = {
val aAnalyzed = a.analyze
val bAnalyzed = b.analyze
@@ -72,4 +78,12 @@ class SameResultSuite extends SparkFunSuite {
val df2 = testRelation.join(testRelation)
assertSameResult(df1, df2)
}
+
+ test("join hint") {
+ val df1 = testRelation.join(testRelation.hint("broadcast"))
+ val df2 = testRelation.join(testRelation)
+ val df1Optimized = Optimize.execute(df1.analyze)
+ val df2Optimized = Optimize.execute(df2.analyze)
+ assertSameResult(df1Optimized, df2Optimized)
+ }
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
index 953094cb0dd52..16a5c2d3001a7 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
@@ -38,24 +38,6 @@ class BasicStatsEstimationSuite extends PlanTest with
StatsEstimationTestBase {
// row count * (overhead + column size)
size = Some(10 * (8 + 4)))
- test("BroadcastHint estimation") {
- val filter = Filter(Literal(true), plan)
- val filterStatsCboOn = Statistics(sizeInBytes = 10 * (8 +4),
- rowCount = Some(10), attributeStats = AttributeMap(Seq(attribute ->
colStat)))
- val filterStatsCboOff = Statistics(sizeInBytes = 10 * (8 +4))
- checkStats(
- filter,
- expectedStatsCboOn = filterStatsCboOn,
- expectedStatsCboOff = filterStatsCboOff)
-
- val broadcastHint = ResolvedHint(filter, HintInfo(broadcast = true))
- checkStats(
- broadcastHint,
- expectedStatsCboOn = filterStatsCboOn.copy(hints = HintInfo(broadcast =
true)),
- expectedStatsCboOff = filterStatsCboOff.copy(hints = HintInfo(broadcast
= true))
- )
- }
-
test("range") {
val range = Range(1, 5, 1, None)
val rangeStats = Statistics(sizeInBytes = 4 * 8)
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
index b0a47e7835129..1cf888519077a 100755
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
@@ -528,7 +528,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase
{
rowCount = 30,
attributeStats = AttributeMap(Seq(attrIntLargerRange ->
colStatIntLargerRange)))
val nonLeafChild = Join(largerTable, smallerTable, LeftOuter,
- Some(EqualTo(attrIntLargerRange, attrInt)))
+ Some(EqualTo(attrIntLargerRange, attrInt)), JoinHint.NONE)
Seq(IsNull(attrIntLargerRange), IsNotNull(attrIntLargerRange)).foreach {
predicate =>
validateEstimatedStats(
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
index 12c0a7be21292..6c5a2b247fc23 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
@@ -79,8 +79,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
val c1 = generateJoinChild(col1, leftHistogram, expectedMin, expectedMax)
val c2 = generateJoinChild(col2, rightHistogram, expectedMin, expectedMax)
- val c1JoinC2 = Join(c1, c2, Inner, Some(EqualTo(col1, col2)))
- val c2JoinC1 = Join(c2, c1, Inner, Some(EqualTo(col2, col1)))
+ val c1JoinC2 = Join(c1, c2, Inner, Some(EqualTo(col1, col2)),
JoinHint.NONE)
+ val c2JoinC1 = Join(c2, c1, Inner, Some(EqualTo(col2, col1)),
JoinHint.NONE)
val expectedStatsAfterJoin = Statistics(
sizeInBytes = expectedRows * (8 + 2 * 4),
rowCount = Some(expectedRows),
@@ -284,7 +284,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
test("cross join") {
// table1 (key-1-5 int, key-5-9 int): (1, 9), (2, 8), (3, 7), (4, 6), (5,
5)
// table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
- val join = Join(table1, table2, Cross, None)
+ val join = Join(table1, table2, Cross, None, JoinHint.NONE)
val expectedStats = Statistics(
sizeInBytes = 5 * 3 * (8 + 4 * 4),
rowCount = Some(5 * 3),
@@ -299,7 +299,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
// table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
// key-5-9 and key-2-4 are disjoint
val join = Join(table1, table2, Inner,
- Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))))
+ Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))),
JoinHint.NONE)
val expectedStats = Statistics(
sizeInBytes = 1,
rowCount = Some(0),
@@ -312,7 +312,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
// table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
// key-5-9 and key-2-4 are disjoint
val join = Join(table1, table2, LeftOuter,
- Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))))
+ Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))),
JoinHint.NONE)
val expectedStats = Statistics(
sizeInBytes = 5 * (8 + 4 * 4),
rowCount = Some(5),
@@ -328,7 +328,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
// table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
// key-5-9 and key-2-4 are disjoint
val join = Join(table1, table2, RightOuter,
- Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))))
+ Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))),
JoinHint.NONE)
val expectedStats = Statistics(
sizeInBytes = 3 * (8 + 4 * 4),
rowCount = Some(3),
@@ -344,7 +344,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
// table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
// key-5-9 and key-2-4 are disjoint
val join = Join(table1, table2, FullOuter,
- Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))))
+ Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))),
JoinHint.NONE)
val expectedStats = Statistics(
sizeInBytes = (5 + 3) * (8 + 4 * 4),
rowCount = Some(5 + 3),
@@ -361,7 +361,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
// table1 (key-1-5 int, key-5-9 int): (1, 9), (2, 8), (3, 7), (4, 6), (5,
5)
// table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
val join = Join(table1, table2, Inner,
- Some(EqualTo(nameToAttr("key-1-5"), nameToAttr("key-1-2"))))
+ Some(EqualTo(nameToAttr("key-1-5"), nameToAttr("key-1-2"))),
JoinHint.NONE)
// Update column stats for equi-join keys (key-1-5 and key-1-2).
val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(1), max
= Some(2),
nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))
@@ -383,7 +383,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
// table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3)
val join = Join(table2, table3, Inner, Some(
And(EqualTo(nameToAttr("key-1-2"), nameToAttr("key-1-2")),
- EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3")))))
+ EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3")))), JoinHint.NONE)
// Update column stats for join keys.
val joinedColStat1 = ColumnStat(distinctCount = Some(2), min = Some(1),
max = Some(2),
@@ -404,7 +404,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
// table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
// table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3)
val join = Join(table3, table2, LeftOuter,
- Some(EqualTo(nameToAttr("key-2-3"), nameToAttr("key-2-4"))))
+ Some(EqualTo(nameToAttr("key-2-3"), nameToAttr("key-2-4"))),
JoinHint.NONE)
val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(2), max
= Some(3),
nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))
@@ -422,7 +422,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
// table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
// table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3)
val join = Join(table2, table3, RightOuter,
- Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))))
+ Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))),
JoinHint.NONE)
val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(2), max
= Some(3),
nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))
@@ -440,7 +440,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
// table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
// table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3)
val join = Join(table2, table3, FullOuter,
- Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))))
+ Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))),
JoinHint.NONE)
val expectedStats = Statistics(
sizeInBytes = 3 * (8 + 4 * 4),
@@ -456,7 +456,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
// table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3)
Seq(LeftSemi, LeftAnti).foreach { jt =>
val join = Join(table2, table3, jt,
- Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))))
+ Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))),
JoinHint.NONE)
// For now we just propagate the statistics from left side for left
semi/anti join.
val expectedStats = Statistics(
sizeInBytes = 3 * (8 + 4 * 2),
@@ -525,7 +525,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
withClue(s"For data type ${key1.dataType}") {
// All values in two tables are the same, so column stats after join
are also the same.
val join = Join(Project(Seq(key1), table1), Project(Seq(key2),
table2), Inner,
- Some(EqualTo(key1, key2)))
+ Some(EqualTo(key1, key2)), JoinHint.NONE)
val expectedStats = Statistics(
sizeInBytes = 1 * (8 + 2 * getColSize(key1, columnInfo1(key1))),
rowCount = Some(1),
@@ -543,7 +543,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
outputList = Seq(nullColumn),
rowCount = 1,
attributeStats = AttributeMap(Seq(nullColumn -> nullColStat)))
- val join = Join(table1, nullTable, Inner,
Some(EqualTo(nameToAttr("key-1-5"), nullColumn)))
+ val join = Join(table1, nullTable, Inner,
+ Some(EqualTo(nameToAttr("key-1-5"), nullColumn)), JoinHint.NONE)
val expectedStats = Statistics(
sizeInBytes = 1,
rowCount = Some(0),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index a664c7338badb..44cada086489a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -862,7 +862,7 @@ class Dataset[T] private[sql](
* @since 2.0.0
*/
def join(right: Dataset[_]): DataFrame = withPlan {
- Join(logicalPlan, right.logicalPlan, joinType = Inner, None)
+ Join(logicalPlan, right.logicalPlan, joinType = Inner, None, JoinHint.NONE)
}
/**
@@ -940,7 +940,7 @@ class Dataset[T] private[sql](
// Analyze the self join. The assumption is that the analyzer will
disambiguate left vs right
// by creating a new instance for one of the branch.
val joined = sparkSession.sessionState.executePlan(
- Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType),
None))
+ Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType),
None, JoinHint.NONE))
.analyzed.asInstanceOf[Join]
withPlan {
@@ -948,7 +948,8 @@ class Dataset[T] private[sql](
joined.left,
joined.right,
UsingJoin(JoinType(joinType), usingColumns),
- None)
+ None,
+ JoinHint.NONE)
}
}
@@ -1001,7 +1002,7 @@ class Dataset[T] private[sql](
// Trigger analysis so in the case of self-join, the analyzer will clone
the plan.
// After the cloning, left and right side will have distinct expression
ids.
val plan = withPlan(
- Join(logicalPlan, right.logicalPlan, JoinType(joinType),
Some(joinExprs.expr)))
+ Join(logicalPlan, right.logicalPlan, JoinType(joinType),
Some(joinExprs.expr), JoinHint.NONE))
.queryExecution.analyzed.asInstanceOf[Join]
// If auto self join alias is disabled, return the plan.
@@ -1048,7 +1049,7 @@ class Dataset[T] private[sql](
* @since 2.1.0
*/
def crossJoin(right: Dataset[_]): DataFrame = withPlan {
- Join(logicalPlan, right.logicalPlan, joinType = Cross, None)
+ Join(logicalPlan, right.logicalPlan, joinType = Cross, None, JoinHint.NONE)
}
/**
@@ -1083,7 +1084,8 @@ class Dataset[T] private[sql](
this.logicalPlan,
other.logicalPlan,
JoinType(joinType),
- Some(condition.expr))).analyzed.asInstanceOf[Join]
+ Some(condition.expr),
+ JoinHint.NONE)).analyzed.asInstanceOf[Join]
if (joined.joinType == LeftSemi || joined.joinType == LeftAnti) {
throw new AnalysisException("Invalid join type in joinWith: " +
joined.joinType.sql)
@@ -1135,7 +1137,7 @@ class Dataset[T] private[sql](
implicit val tuple2Encoder: Encoder[(T, U)] =
ExpressionEncoder.tuple(this.exprEnc, other.exprEnc)
- withTypedPlan(Join(left, right, joined.joinType, Some(conditionExpr)))
+ withTypedPlan(Join(left, right, joined.joinType, Some(conditionExpr),
JoinHint.NONE))
}
/**
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index dbc6db62bd820..b7cc373b2df12 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -208,17 +208,17 @@ abstract class SparkStrategies extends
QueryPlanner[SparkPlan] {
}
}
- private def canBroadcastByHints(joinType: JoinType, left: LogicalPlan,
right: LogicalPlan)
- : Boolean = {
- val buildLeft = canBuildLeft(joinType) && left.stats.hints.broadcast
- val buildRight = canBuildRight(joinType) && right.stats.hints.broadcast
+ private def canBroadcastByHints(
+ joinType: JoinType, left: LogicalPlan, right: LogicalPlan, hint:
JoinHint): Boolean = {
+ val buildLeft = canBuildLeft(joinType) &&
hint.leftHint.exists(_.broadcast)
+ val buildRight = canBuildRight(joinType) &&
hint.rightHint.exists(_.broadcast)
buildLeft || buildRight
}
- private def broadcastSideByHints(joinType: JoinType, left: LogicalPlan,
right: LogicalPlan)
- : BuildSide = {
- val buildLeft = canBuildLeft(joinType) && left.stats.hints.broadcast
- val buildRight = canBuildRight(joinType) && right.stats.hints.broadcast
+ private def broadcastSideByHints(
+ joinType: JoinType, left: LogicalPlan, right: LogicalPlan, hint:
JoinHint): BuildSide = {
+ val buildLeft = canBuildLeft(joinType) &&
hint.leftHint.exists(_.broadcast)
+ val buildRight = canBuildRight(joinType) &&
hint.rightHint.exists(_.broadcast)
broadcastSide(buildLeft, buildRight, left, right)
}
@@ -241,14 +241,14 @@ abstract class SparkStrategies extends
QueryPlanner[SparkPlan] {
// --- BroadcastHashJoin
--------------------------------------------------------------------
// broadcast hints were specified
- case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left,
right)
- if canBroadcastByHints(joinType, left, right) =>
- val buildSide = broadcastSideByHints(joinType, left, right)
+ case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left,
right, hint)
+ if canBroadcastByHints(joinType, left, right, hint) =>
+ val buildSide = broadcastSideByHints(joinType, left, right, hint)
Seq(joins.BroadcastHashJoinExec(
leftKeys, rightKeys, joinType, buildSide, condition,
planLater(left), planLater(right)))
// broadcast hints were not specified, so need to infer it from size and
configuration.
- case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left,
right)
+ case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left,
right, _)
if canBroadcastBySizes(joinType, left, right) =>
val buildSide = broadcastSideBySizes(joinType, left, right)
Seq(joins.BroadcastHashJoinExec(
@@ -256,14 +256,14 @@ abstract class SparkStrategies extends
QueryPlanner[SparkPlan] {
// --- ShuffledHashJoin
---------------------------------------------------------------------
- case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left,
right)
+ case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left,
right, _)
if !conf.preferSortMergeJoin && canBuildRight(joinType) &&
canBuildLocalHashMap(right)
&& muchSmaller(right, left) ||
!RowOrdering.isOrderable(leftKeys) =>
Seq(joins.ShuffledHashJoinExec(
leftKeys, rightKeys, joinType, BuildRight, condition,
planLater(left), planLater(right)))
- case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left,
right)
+ case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left,
right, _)
if !conf.preferSortMergeJoin && canBuildLeft(joinType) &&
canBuildLocalHashMap(left)
&& muchSmaller(left, right) ||
!RowOrdering.isOrderable(leftKeys) =>
@@ -272,7 +272,7 @@ abstract class SparkStrategies extends
QueryPlanner[SparkPlan] {
// --- SortMergeJoin
------------------------------------------------------------
- case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left,
right)
+ case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left,
right, _)
if RowOrdering.isOrderable(leftKeys) =>
joins.SortMergeJoinExec(
leftKeys, rightKeys, joinType, condition, planLater(left),
planLater(right)) :: Nil
@@ -280,25 +280,25 @@ abstract class SparkStrategies extends
QueryPlanner[SparkPlan] {
// --- Without joining keys
------------------------------------------------------------
// Pick BroadcastNestedLoopJoin if one side could be broadcast
- case j @ logical.Join(left, right, joinType, condition)
- if canBroadcastByHints(joinType, left, right) =>
- val buildSide = broadcastSideByHints(joinType, left, right)
+ case j @ logical.Join(left, right, joinType, condition, hint)
+ if canBroadcastByHints(joinType, left, right, hint) =>
+ val buildSide = broadcastSideByHints(joinType, left, right, hint)
joins.BroadcastNestedLoopJoinExec(
planLater(left), planLater(right), buildSide, joinType, condition)
:: Nil
- case j @ logical.Join(left, right, joinType, condition)
+ case j @ logical.Join(left, right, joinType, condition, _)
if canBroadcastBySizes(joinType, left, right) =>
val buildSide = broadcastSideBySizes(joinType, left, right)
joins.BroadcastNestedLoopJoinExec(
planLater(left), planLater(right), buildSide, joinType, condition)
:: Nil
// Pick CartesianProduct for InnerJoin
- case logical.Join(left, right, _: InnerLike, condition) =>
+ case logical.Join(left, right, _: InnerLike, condition, _) =>
joins.CartesianProductExec(planLater(left), planLater(right),
condition) :: Nil
- case logical.Join(left, right, joinType, condition) =>
+ case logical.Join(left, right, joinType, condition, hint) =>
val buildSide = broadcastSide(
- left.stats.hints.broadcast, right.stats.hints.broadcast, left, right)
+ hint.leftHint.exists(_.broadcast),
hint.rightHint.exists(_.broadcast), left, right)
// This join could be very slow or OOM
joins.BroadcastNestedLoopJoinExec(
planLater(left), planLater(right), buildSide, joinType, condition)
:: Nil
@@ -380,13 +380,13 @@ abstract class SparkStrategies extends
QueryPlanner[SparkPlan] {
object StreamingJoinStrategy extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = {
plan match {
- case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition,
left, right)
+ case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition,
left, right, _)
if left.isStreaming && right.isStreaming =>
new StreamingSymmetricHashJoinExec(
leftKeys, rightKeys, joinType, condition, planLater(left),
planLater(right)) :: Nil
- case Join(left, right, _, _) if left.isStreaming && right.isStreaming
=>
+ case Join(left, right, _, _, _) if left.isStreaming &&
right.isStreaming =>
throw new AnalysisException(
"Stream-stream join without equality predicate is not supported",
plan = Some(plan))
@@ -561,6 +561,9 @@ abstract class SparkStrategies extends
QueryPlanner[SparkPlan] {
throw new IllegalStateException(
"logical except (all) operator should have been replaced by union,
aggregate" +
" and generate operators in the optimizer")
+ case logical.ResolvedHint(child, hints) =>
+ throw new IllegalStateException(
+ "ResolvedHint operator should have been replaced by join hint in the
optimizer")
case logical.DeserializeToObject(deserializer, objAttr, child) =>
execution.DeserializeToObjectExec(deserializer, objAttr,
planLater(child)) :: Nil
@@ -632,7 +635,6 @@ abstract class SparkStrategies extends
QueryPlanner[SparkPlan] {
case ExternalRDD(outputObjAttr, rdd) =>
ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
case r: LogicalRDD =>
RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning,
r.outputOrdering) :: Nil
- case h: ResolvedHint => planLater(h.child) :: Nil
case _ => Nil
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 4109d9994dd8f..41f406d6c2993 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -26,7 +26,7 @@ import
org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, LogicalPlan,
Statistics}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.storage.StorageLevel
@@ -184,12 +184,7 @@ case class InMemoryRelation(
override def computeStats(): Statistics = {
if (cacheBuilder.sizeInBytesStats.value == 0L) {
// Underlying columnar RDD hasn't been materialized, use the stats from
the plan to cache.
- // Note that we should drop the hint info here. We may cache a plan
whose root node is a hint
- // node. When we lookup the cache with a semantically same plan without
hint info, the plan
- // returned by cache lookup should not have hint info. If we lookup the
cache with a
- // semantically same plan with a different hint info,
`CacheManager.useCachedData` will take
- // care of it and retain the hint info in the lookup input plan.
- statsOfPlanToCache.copy(hints = HintInfo())
+ statsOfPlanToCache
} else {
Statistics(sizeInBytes = cacheBuilder.sizeInBytesStats.value.longValue)
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 6e805c4f3c39a..2141be4d680f9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -27,9 +27,11 @@ import
org.apache.spark.executor.DataReadMethod.DataReadMethod
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
+import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan}
import org.apache.spark.sql.execution.columnar._
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
import org.apache.spark.storage.{RDDBlockId, StorageLevel}
@@ -925,4 +927,23 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
with SharedSQLContext
}
}
}
+
+ test("Cache should respect the broadcast hint") {
+ val df = broadcast(spark.range(1000)).cache()
+ val df2 = spark.range(1000).cache()
+ df.count()
+ df2.count()
+
+ // Test the broadcast hint.
+ val joinPlan = df.join(df2, "id").queryExecution.optimizedPlan
+ val hint = joinPlan.collect {
+ case Join(_, _, _, _, hint) => hint
+ }
+ assert(hint.size == 1)
+ assert(hint(0).leftHint.get.broadcast)
+ assert(hint(0).rightHint.isEmpty)
+
+ // Clean-up
+ df.unpersist()
+ }
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index c9f41ab1c0179..a4a3e2a62d1a5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -198,7 +198,7 @@ class DataFrameJoinSuite extends QueryTest with
SharedSQLContext {
// outer -> left
val outerJoin2Left = df.join(df2, $"a.int" === $"b.int",
"outer").where($"a.int" >= 3)
assert(outerJoin2Left.queryExecution.optimizedPlan.collect {
- case j @ Join(_, _, LeftOuter, _) => j }.size === 1)
+ case j @ Join(_, _, LeftOuter, _, _) => j }.size === 1)
checkAnswer(
outerJoin2Left,
Row(3, 4, "3", null, null, null) :: Nil)
@@ -206,7 +206,7 @@ class DataFrameJoinSuite extends QueryTest with
SharedSQLContext {
// outer -> right
val outerJoin2Right = df.join(df2, $"a.int" === $"b.int",
"outer").where($"b.int" >= 3)
assert(outerJoin2Right.queryExecution.optimizedPlan.collect {
- case j @ Join(_, _, RightOuter, _) => j }.size === 1)
+ case j @ Join(_, _, RightOuter, _, _) => j }.size === 1)
checkAnswer(
outerJoin2Right,
Row(null, null, null, 5, 6, "5") :: Nil)
@@ -215,7 +215,7 @@ class DataFrameJoinSuite extends QueryTest with
SharedSQLContext {
val outerJoin2Inner = df.join(df2, $"a.int" === $"b.int", "outer").
where($"a.int" === 1 && $"b.int2" === 3)
assert(outerJoin2Inner.queryExecution.optimizedPlan.collect {
- case j @ Join(_, _, Inner, _) => j }.size === 1)
+ case j @ Join(_, _, Inner, _, _) => j }.size === 1)
checkAnswer(
outerJoin2Inner,
Row(1, 2, "1", 1, 3, "1") :: Nil)
@@ -223,7 +223,7 @@ class DataFrameJoinSuite extends QueryTest with
SharedSQLContext {
// right -> inner
val rightJoin2Inner = df.join(df2, $"a.int" === $"b.int",
"right").where($"a.int" > 0)
assert(rightJoin2Inner.queryExecution.optimizedPlan.collect {
- case j @ Join(_, _, Inner, _) => j }.size === 1)
+ case j @ Join(_, _, Inner, _, _) => j }.size === 1)
checkAnswer(
rightJoin2Inner,
Row(1, 2, "1", 1, 3, "1") :: Nil)
@@ -231,7 +231,7 @@ class DataFrameJoinSuite extends QueryTest with
SharedSQLContext {
// left -> inner
val leftJoin2Inner = df.join(df2, $"a.int" === $"b.int",
"left").where($"b.int2" > 0)
assert(leftJoin2Inner.queryExecution.optimizedPlan.collect {
- case j @ Join(_, _, Inner, _) => j }.size === 1)
+ case j @ Join(_, _, Inner, _, _) => j }.size === 1)
checkAnswer(
leftJoin2Inner,
Row(1, 2, "1", 1, 3, "1") :: Nil)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
new file mode 100644
index 0000000000000..3652895ff43d8
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.test.SharedSQLContext
+
+class JoinHintSuite extends PlanTest with SharedSQLContext {
+ import testImplicits._
+
+ lazy val df = spark.range(10)
+ lazy val df1 = df.selectExpr("id as a1", "id as a2")
+ lazy val df2 = df.selectExpr("id as b1", "id as b2")
+ lazy val df3 = df.selectExpr("id as c1", "id as c2")
+
+ def verifyJoinHint(df: DataFrame, expectedHints: Seq[JoinHint]): Unit = {
+ val optimized = df.queryExecution.optimizedPlan
+ val joinHints = optimized collect {
+ case Join(_, _, _, _, hint) => hint
+ case _: ResolvedHint => fail("ResolvedHint should not appear after
optimize.")
+ }
+ assert(joinHints == expectedHints)
+ }
+
+ test("single join") {
+ verifyJoinHint(
+ df.hint("broadcast").join(df, "id"),
+ JoinHint(
+ Some(HintInfo(broadcast = true)),
+ None) :: Nil
+ )
+ verifyJoinHint(
+ df.join(df.hint("broadcast"), "id"),
+ JoinHint(
+ None,
+ Some(HintInfo(broadcast = true))) :: Nil
+ )
+ }
+
+ test("multiple joins") {
+ verifyJoinHint(
+ df1.join(df2.hint("broadcast").join(df3, 'b1 === 'c1).hint("broadcast"),
'a1 === 'c1),
+ JoinHint(
+ None,
+ Some(HintInfo(broadcast = true))) ::
+ JoinHint(
+ Some(HintInfo(broadcast = true)),
+ None) :: Nil
+ )
+ verifyJoinHint(
+ df1.hint("broadcast").join(df2, 'a1 === 'b1).hint("broadcast").join(df3,
'a1 === 'c1),
+ JoinHint(
+ Some(HintInfo(broadcast = true)),
+ None) ::
+ JoinHint(
+ Some(HintInfo(broadcast = true)),
+ None) :: Nil
+ )
+ }
+
+ test("hint scope") {
+ withTempView("a", "b", "c") {
+ df1.createOrReplaceTempView("a")
+ df2.createOrReplaceTempView("b")
+ verifyJoinHint(
+ sql(
+ """
+ |select /*+ broadcast(a, b)*/ * from (
+ | select /*+ broadcast(b)*/ * from a join b on a.a1 = b.b1
+ |) a join (
+ | select /*+ broadcast(a)*/ * from a join b on a.a1 = b.b1
+ |) b on a.a1 = b.b1
+ """.stripMargin),
+ JoinHint(
+ Some(HintInfo(broadcast = true)),
+ Some(HintInfo(broadcast = true))) ::
+ JoinHint(
+ None,
+ Some(HintInfo(broadcast = true))) ::
+ JoinHint(
+ Some(HintInfo(broadcast = true)),
+ None) :: Nil
+ )
+ }
+ }
+
+ test("hint preserved after join reorder") {
+ withTempView("a", "b", "c") {
+ df1.createOrReplaceTempView("a")
+ df2.createOrReplaceTempView("b")
+ df3.createOrReplaceTempView("c")
+ verifyJoinHint(
+ sql("select /*+ broadcast(a, c)*/ * from a, b, c " +
+ "where a.a1 = b.b1 and b.b1 = c.c1"),
+ JoinHint(
+ None,
+ Some(HintInfo(broadcast = true))) ::
+ JoinHint(
+ Some(HintInfo(broadcast = true)),
+ None):: Nil
+ )
+ verifyJoinHint(
+ sql("select /*+ broadcast(a, c)*/ * from a, c, b " +
+ "where a.a1 = b.b1 and b.b1 = c.c1"),
+ JoinHint(
+ None,
+ Some(HintInfo(broadcast = true))) ::
+ JoinHint(
+ Some(HintInfo(broadcast = true)),
+ None):: Nil
+ )
+ verifyJoinHint(
+ sql("select /*+ broadcast(b, c)*/ * from a, c, b " +
+ "where a.a1 = b.b1 and b.b1 = c.c1"),
+ JoinHint(
+ None,
+ Some(HintInfo(broadcast = true))) ::
+ JoinHint(
+ None,
+ Some(HintInfo(broadcast = true))):: Nil
+ )
+
+ verifyJoinHint(
+ df1.join(df2, 'a1 === 'b1 && 'a1 > 5).hint("broadcast")
+ .join(df3, 'b1 === 'c1 && 'a1 < 10),
+ JoinHint(
+ Some(HintInfo(broadcast = true)),
+ None) ::
+ JoinHint.NONE:: Nil
+ )
+
+ verifyJoinHint(
+ df1.join(df2, 'a1 === 'b1 && 'a1 > 5).hint("broadcast")
+ .join(df3, 'b1 === 'c1 && 'a1 < 10)
+ .join(df, 'b1 === 'id),
+ JoinHint.NONE ::
+ JoinHint(
+ Some(HintInfo(broadcast = true)),
+ None) ::
+ JoinHint.NONE:: Nil
+ )
+ }
+ }
+
+ test("intersect/except") {
+ val dfSub = spark.range(2)
+ verifyJoinHint(
+ df.hint("broadcast").except(dfSub).join(df, "id"),
+ JoinHint(
+ Some(HintInfo(broadcast = true)),
+ None) ::
+ JoinHint.NONE :: Nil
+ )
+ verifyJoinHint(
+ df.join(df.hint("broadcast").intersect(dfSub), "id"),
+ JoinHint(
+ None,
+ Some(HintInfo(broadcast = true))) ::
+ JoinHint.NONE :: Nil
+ )
+ }
+
+ test("hint merge") {
+ verifyJoinHint(
+ df.hint("broadcast").filter('id > 2).hint("broadcast").join(df, "id"),
+ JoinHint(
+ Some(HintInfo(broadcast = true)),
+ None) :: Nil
+ )
+ verifyJoinHint(
+ df.join(df.hint("broadcast").limit(2).hint("broadcast"), "id"),
+ JoinHint(
+ None,
+ Some(HintInfo(broadcast = true))) :: Nil
+ )
+ }
+}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
index 02dc32d5f90ba..99842680cedfe 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
@@ -237,8 +237,7 @@ class StatisticsCollectionSuite extends
StatisticsCollectionTestBase with Shared
)
numbers.foreach { case (input, (expectedSize, expectedRows)) =>
val stats = Statistics(sizeInBytes = input, rowCount = Some(input))
- val expectedString = s"sizeInBytes=$expectedSize,
rowCount=$expectedRows," +
- s" hints=none"
+ val expectedString = s"sizeInBytes=$expectedSize, rowCount=$expectedRows"
assert(stats.simpleString == expectedString)
}
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 42dd0024b2582..f238148e61c39 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -203,7 +203,7 @@ class BroadcastJoinSuite extends QueryTest with
SQLTestUtils {
}
test("broadcast hint in SQL") {
- import org.apache.spark.sql.catalyst.plans.logical.{ResolvedHint, Join}
+ import org.apache.spark.sql.catalyst.plans.logical.Join
spark.range(10).createOrReplaceTempView("t")
spark.range(10).createOrReplaceTempView("u")
@@ -216,12 +216,12 @@ class BroadcastJoinSuite extends QueryTest with
SQLTestUtils {
val plan3 = sql(s"SELECT /*+ $name(v) */ * FROM t JOIN u ON t.id =
u.id").queryExecution
.optimizedPlan
- assert(plan1.asInstanceOf[Join].left.isInstanceOf[ResolvedHint])
- assert(!plan1.asInstanceOf[Join].right.isInstanceOf[ResolvedHint])
- assert(!plan2.asInstanceOf[Join].left.isInstanceOf[ResolvedHint])
- assert(plan2.asInstanceOf[Join].right.isInstanceOf[ResolvedHint])
- assert(!plan3.asInstanceOf[Join].left.isInstanceOf[ResolvedHint])
- assert(!plan3.asInstanceOf[Join].right.isInstanceOf[ResolvedHint])
+ assert(plan1.asInstanceOf[Join].hint.leftHint.get.broadcast)
+ assert(plan1.asInstanceOf[Join].hint.rightHint.isEmpty)
+ assert(plan2.asInstanceOf[Join].hint.leftHint.isEmpty)
+ assert(plan2.asInstanceOf[Join].hint.rightHint.get.broadcast)
+ assert(plan3.asInstanceOf[Join].hint.leftHint.isEmpty)
+ assert(plan3.asInstanceOf[Join].hint.rightHint.isEmpty)
}
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
index 22279a3a43eff..771a9730247af 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.Join
+import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint}
import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan,
SparkPlanTest}
import org.apache.spark.sql.execution.exchange.EnsureRequirements
import org.apache.spark.sql.internal.SQLConf
@@ -85,7 +85,8 @@ class ExistenceJoinSuite extends SparkPlanTest with
SharedSQLContext {
expectedAnswer: Seq[Row]): Unit = {
def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
- val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner,
Some(condition))
+ val join = Join(leftRows.logicalPlan, rightRows.logicalPlan,
+ Inner, Some(condition), JoinHint.NONE)
ExtractEquiJoinKeys.unapply(join)
}
@@ -102,7 +103,7 @@ class ExistenceJoinSuite extends SparkPlanTest with
SharedSQLContext {
}
test(s"$testName using ShuffledHashJoin") {
- extractJoinParts().foreach { case (_, leftKeys, rightKeys,
boundCondition, _, _) =>
+ extractJoinParts().foreach { case (_, leftKeys, rightKeys,
boundCondition, _, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right:
SparkPlan) =>
EnsureRequirements(left.sqlContext.sessionState.conf).apply(
@@ -121,7 +122,7 @@ class ExistenceJoinSuite extends SparkPlanTest with
SharedSQLContext {
}
testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin") {
_ =>
- extractJoinParts().foreach { case (_, leftKeys, rightKeys,
boundCondition, _, _) =>
+ extractJoinParts().foreach { case (_, leftKeys, rightKeys,
boundCondition, _, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right:
SparkPlan) =>
EnsureRequirements(left.sqlContext.sessionState.conf).apply(
@@ -140,7 +141,7 @@ class ExistenceJoinSuite extends SparkPlanTest with
SharedSQLContext {
}
test(s"$testName using SortMergeJoin") {
- extractJoinParts().foreach { case (_, leftKeys, rightKeys,
boundCondition, _, _) =>
+ extractJoinParts().foreach { case (_, leftKeys, rightKeys,
boundCondition, _, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right:
SparkPlan) =>
EnsureRequirements(left.sqlContext.sessionState.conf).apply(
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index f5edd6bbd5e69..f99a278bb2427 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
-import org.apache.spark.sql.catalyst.plans.logical.Join
+import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.EnsureRequirements
import org.apache.spark.sql.internal.SQLConf
@@ -80,7 +80,8 @@ class InnerJoinSuite extends SparkPlanTest with
SharedSQLContext {
expectedAnswer: Seq[Product]): Unit = {
def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
- val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner,
Some(condition()))
+ val join = Join(leftRows.logicalPlan, rightRows.logicalPlan,
+ Inner, Some(condition()), JoinHint.NONE)
ExtractEquiJoinKeys.unapply(join)
}
@@ -128,7 +129,7 @@ class InnerJoinSuite extends SparkPlanTest with
SharedSQLContext {
}
testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin
(build=left)") { _ =>
- extractJoinParts().foreach { case (_, leftKeys, rightKeys,
boundCondition, _, _) =>
+ extractJoinParts().foreach { case (_, leftKeys, rightKeys,
boundCondition, _, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan:
SparkPlan) =>
makeBroadcastHashJoin(
@@ -140,7 +141,7 @@ class InnerJoinSuite extends SparkPlanTest with
SharedSQLContext {
}
testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin
(build=right)") { _ =>
- extractJoinParts().foreach { case (_, leftKeys, rightKeys,
boundCondition, _, _) =>
+ extractJoinParts().foreach { case (_, leftKeys, rightKeys,
boundCondition, _, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan:
SparkPlan) =>
makeBroadcastHashJoin(
@@ -152,7 +153,7 @@ class InnerJoinSuite extends SparkPlanTest with
SharedSQLContext {
}
test(s"$testName using ShuffledHashJoin (build=left)") {
- extractJoinParts().foreach { case (_, leftKeys, rightKeys,
boundCondition, _, _) =>
+ extractJoinParts().foreach { case (_, leftKeys, rightKeys,
boundCondition, _, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan:
SparkPlan) =>
makeShuffledHashJoin(
@@ -164,7 +165,7 @@ class InnerJoinSuite extends SparkPlanTest with
SharedSQLContext {
}
test(s"$testName using ShuffledHashJoin (build=right)") {
- extractJoinParts().foreach { case (_, leftKeys, rightKeys,
boundCondition, _, _) =>
+ extractJoinParts().foreach { case (_, leftKeys, rightKeys,
boundCondition, _, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan:
SparkPlan) =>
makeShuffledHashJoin(
@@ -176,7 +177,7 @@ class InnerJoinSuite extends SparkPlanTest with
SharedSQLContext {
}
testWithWholeStageCodegenOnAndOff(s"$testName using SortMergeJoin") { _ =>
- extractJoinParts().foreach { case (_, leftKeys, rightKeys,
boundCondition, _, _) =>
+ extractJoinParts().foreach { case (_, leftKeys, rightKeys,
boundCondition, _, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan:
SparkPlan) =>
makeSortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan,
rightPlan),
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 513248dae48be..1f04fcf6ca451 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.Join
+import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint}
import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
import org.apache.spark.sql.execution.exchange.EnsureRequirements
import org.apache.spark.sql.internal.SQLConf
@@ -72,13 +72,14 @@ class OuterJoinSuite extends SparkPlanTest with
SharedSQLContext {
expectedAnswer: Seq[Product]): Unit = {
def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
- val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner,
Some(condition))
+ val join = Join(leftRows.logicalPlan, rightRows.logicalPlan,
+ Inner, Some(condition), JoinHint.NONE)
ExtractEquiJoinKeys.unapply(join)
}
if (joinType != FullOuter) {
test(s"$testName using ShuffledHashJoin") {
- extractJoinParts().foreach { case (_, leftKeys, rightKeys,
boundCondition, _, _) =>
+ extractJoinParts().foreach { case (_, leftKeys, rightKeys,
boundCondition, _, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
val buildSide = if (joinType == LeftOuter) BuildRight else
BuildLeft
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right:
SparkPlan) =>
@@ -99,7 +100,7 @@ class OuterJoinSuite extends SparkPlanTest with
SharedSQLContext {
case RightOuter => BuildLeft
case _ => fail(s"Unsupported join type $joinType")
}
- extractJoinParts().foreach { case (_, leftKeys, rightKeys,
boundCondition, _, _) =>
+ extractJoinParts().foreach { case (_, leftKeys, rightKeys,
boundCondition, _, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right:
SparkPlan) =>
BroadcastHashJoinExec(
@@ -112,7 +113,7 @@ class OuterJoinSuite extends SparkPlanTest with
SharedSQLContext {
}
test(s"$testName using SortMergeJoin") {
- extractJoinParts().foreach { case (_, leftKeys, rightKeys,
boundCondition, _, _) =>
+ extractJoinParts().foreach { case (_, leftKeys, rightKeys,
boundCondition, _, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right:
SparkPlan) =>
EnsureRequirements(spark.sessionState.conf).apply(
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]