asfgit closed pull request #23036: [SPARK-26065][SQL] Change query hint from a `LogicalPlan` to a field
URL: https://github.com/apache/spark/pull/23036
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub does not otherwise display it for merged pull requests from forks:

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 198645d875c47..2aa0f2117364c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -943,7 +943,7 @@ class Analyzer(
         failAnalysis("Invalid usage of '*' in explode/json_tuple/UDTF")
 
       // To resolve duplicate expression IDs for Join and Intersect
-      case j @ Join(left, right, _, _) if !j.duplicateResolved =>
+      case j @ Join(left, right, _, _, _) if !j.duplicateResolved =>
         j.copy(right = dedupRight(left, right))
       case i @ Intersect(left, right, _) if !i.duplicateResolved =>
         i.copy(right = dedupRight(left, right))
@@ -2249,13 +2249,14 @@ class Analyzer(
    */
   object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] {
     override def apply(plan: LogicalPlan): LogicalPlan = 
plan.resolveOperatorsUp {
-      case j @ Join(left, right, UsingJoin(joinType, usingCols), _)
+      case j @ Join(left, right, UsingJoin(joinType, usingCols), _, hint)
           if left.resolved && right.resolved && j.duplicateResolved =>
-        commonNaturalJoinProcessing(left, right, joinType, usingCols, None)
-      case j @ Join(left, right, NaturalJoin(joinType), condition) if 
j.resolvedExceptNatural =>
+        commonNaturalJoinProcessing(left, right, joinType, usingCols, None, 
hint)
+      case j @ Join(left, right, NaturalJoin(joinType), condition, hint)
+          if j.resolvedExceptNatural =>
         // find common column names from both sides
         val joinNames = 
left.output.map(_.name).intersect(right.output.map(_.name))
-        commonNaturalJoinProcessing(left, right, joinType, joinNames, 
condition)
+        commonNaturalJoinProcessing(left, right, joinType, joinNames, 
condition, hint)
     }
   }
 
@@ -2360,7 +2361,8 @@ class Analyzer(
       right: LogicalPlan,
       joinType: JoinType,
       joinNames: Seq[String],
-      condition: Option[Expression]) = {
+      condition: Option[Expression],
+      hint: JoinHint) = {
     val leftKeys = joinNames.map { keyName =>
       left.output.find(attr => resolver(attr.name, keyName)).getOrElse {
         throw new AnalysisException(s"USING column `$keyName` cannot be 
resolved on the left " +
@@ -2401,7 +2403,7 @@ class Analyzer(
         sys.error("Unsupported natural join type " + joinType)
     }
     // use Project to trim unnecessary fields
-    Project(projectList, Join(left, right, joinType, newCondition))
+    Project(projectList, Join(left, right, joinType, newCondition, hint))
   }
 
   /**
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index c28a97839fe49..18c40b370cb5f 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -172,7 +172,7 @@ trait CheckAnalysis extends PredicateHelper {
             failAnalysis("Null-aware predicate sub-queries cannot be used in 
nested " +
               s"conditions: $condition")
 
-          case j @ Join(_, _, _, Some(condition)) if condition.dataType != 
BooleanType =>
+          case j @ Join(_, _, _, Some(condition), _) if condition.dataType != 
BooleanType =>
             failAnalysis(
               s"join condition '${condition.sql}' " +
                 s"of type ${condition.dataType.catalogString} is not a 
boolean.")
@@ -609,7 +609,7 @@ trait CheckAnalysis extends PredicateHelper {
         failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a)
 
       // Join can host correlated expressions.
-      case j @ Join(left, right, joinType, _) =>
+      case j @ Join(left, right, joinType, _, _) =>
         joinType match {
           // Inner join, like Filter, can be anywhere.
           case _: InnerLike =>
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala
index 7a0aa08289efa..76733dd6dac3c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala
@@ -41,7 +41,7 @@ object StreamingJoinHelper extends PredicateHelper with 
Logging {
    */
   def isWatermarkInJoinKeys(plan: LogicalPlan): Boolean = {
     plan match {
-      case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _, _) =>
+      case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _, _, _) =>
         (leftKeys ++ rightKeys).exists {
           case a: AttributeReference => 
a.metadata.contains(EventTimeWatermark.delayKey)
           case _ => false
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index cff4cee09427f..41ba6d34b5499 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -229,7 +229,7 @@ object UnsupportedOperationChecker {
           throwError("dropDuplicates is not supported after aggregation on a " 
+
             "streaming DataFrame/Dataset")
 
-        case Join(left, right, joinType, condition) =>
+        case Join(left, right, joinType, condition, _) =>
 
           joinType match {
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 151481c80ee96..846ee3b386527 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -325,7 +325,7 @@ package object dsl {
         otherPlan: LogicalPlan,
         joinType: JoinType = Inner,
         condition: Option[Expression] = None): LogicalPlan =
-        Join(logicalPlan, otherPlan, joinType, condition)
+        Join(logicalPlan, otherPlan, joinType, condition, JoinHint.NONE)
 
       def cogroup[Key: Encoder, Left: Encoder, Right: Encoder, Result: 
Encoder](
           otherPlan: LogicalPlan,
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
index 01634a9d852c6..743d3ce944fe2 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.{And, Attribute, 
AttributeSet, Expression, PredicateHelper}
 import org.apache.spark.sql.catalyst.plans.{Inner, InnerLike, JoinType}
-import org.apache.spark.sql.catalyst.plans.logical.{BinaryNode, Join, 
LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.internal.SQLConf
 
@@ -31,6 +31,40 @@ import org.apache.spark.sql.internal.SQLConf
  * Cost-based join reorder.
  * We may have several join reorder algorithms in the future. This class is 
the entry of these
  * algorithms, and chooses which one to use.
+ *
+ * Note that join strategy hints, e.g. the broadcast hint, do not interfere 
with the reordering.
+ * Such hints will be applied on the equivalent counterparts (i.e., join 
between the same relations
+ * regardless of the join order) of the original nodes after reordering.
+ * For example, the plan before reordering is like:
+ *
+ *            Join
+ *            /   \
+ *          Hint1 t4
+ *          /
+ *        Join
+ *        / \
+ *      Join t3
+ *      /   \
+ *    Hint2 t2
+ *    /
+ *   t1
+ *
+ * The original join order as illustrated above is "((t1 JOIN t2) JOIN t3) 
JOIN t4", and after
+ * reordering, the new join order is "((t1 JOIN t3) JOIN t2) JOIN t4", so the 
new plan will be like:
+ *
+ *            Join
+ *            /   \
+ *          Hint1 t4
+ *          /
+ *        Join
+ *        / \
+ *      Join t2
+ *      / \
+ *    t1  t3
+ *
+ * "Hint1" is applied on "(t1 JOIN t3) JOIN t2" as it is equivalent to the 
original hinted node,
+ * "(t1 JOIN t2) JOIN t3"; while "Hint2" has disappeared from the new plan 
since there is no
+ * equivalent node to "t1 JOIN t2".
  */
 object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
 
@@ -40,24 +74,30 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with 
PredicateHelper {
     if (!conf.cboEnabled || !conf.joinReorderEnabled) {
       plan
     } else {
+      // Use a map to track the hints on the join items.
+      val hintMap = new mutable.HashMap[AttributeSet, HintInfo]
       val result = plan transformDown {
         // Start reordering with a joinable item, which is an InnerLike join 
with conditions.
-        case j @ Join(_, _, _: InnerLike, Some(cond)) =>
-          reorder(j, j.output)
-        case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond)))
+        case j @ Join(_, _, _: InnerLike, Some(cond), _) =>
+          reorder(j, j.output, hintMap)
+        case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond), _))
           if projectList.forall(_.isInstanceOf[Attribute]) =>
-          reorder(p, p.output)
+          reorder(p, p.output, hintMap)
       }
-
-      // After reordering is finished, convert OrderedJoin back to Join
-      result transformDown {
-        case OrderedJoin(left, right, jt, cond) => Join(left, right, jt, cond)
+      // After reordering is finished, convert OrderedJoin back to Join.
+      result transform {
+        case OrderedJoin(left, right, jt, cond) =>
+          val joinHint = JoinHint(hintMap.get(left.outputSet), 
hintMap.get(right.outputSet))
+          Join(left, right, jt, cond, joinHint)
       }
     }
   }
 
-  private def reorder(plan: LogicalPlan, output: Seq[Attribute]): LogicalPlan 
= {
-    val (items, conditions) = extractInnerJoins(plan)
+  private def reorder(
+      plan: LogicalPlan,
+      output: Seq[Attribute],
+      hintMap: mutable.HashMap[AttributeSet, HintInfo]): LogicalPlan = {
+    val (items, conditions) = extractInnerJoins(plan, hintMap)
     val result =
       // Do reordering if the number of items is appropriate and join 
conditions exist.
       // We also need to check if costs of all items can be evaluated.
@@ -75,27 +115,31 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with 
PredicateHelper {
    * Extracts items of consecutive inner joins and join conditions.
    * This method works for bushy trees and left/right deep trees.
    */
-  private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], 
Set[Expression]) = {
+  private def extractInnerJoins(
+      plan: LogicalPlan,
+      hintMap: mutable.HashMap[AttributeSet, HintInfo]): (Seq[LogicalPlan], 
Set[Expression]) = {
     plan match {
-      case Join(left, right, _: InnerLike, Some(cond)) =>
-        val (leftPlans, leftConditions) = extractInnerJoins(left)
-        val (rightPlans, rightConditions) = extractInnerJoins(right)
+      case Join(left, right, _: InnerLike, Some(cond), hint) =>
+        hint.leftHint.foreach(hintMap.put(left.outputSet, _))
+        hint.rightHint.foreach(hintMap.put(right.outputSet, _))
+        val (leftPlans, leftConditions) = extractInnerJoins(left, hintMap)
+        val (rightPlans, rightConditions) = extractInnerJoins(right, hintMap)
         (leftPlans ++ rightPlans, splitConjunctivePredicates(cond).toSet ++
           leftConditions ++ rightConditions)
-      case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond)))
+      case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), _))
         if projectList.forall(_.isInstanceOf[Attribute]) =>
-        extractInnerJoins(j)
+        extractInnerJoins(j, hintMap)
       case _ =>
         (Seq(plan), Set())
     }
   }
 
   private def replaceWithOrderedJoin(plan: LogicalPlan): LogicalPlan = plan 
match {
-    case j @ Join(left, right, jt: InnerLike, Some(cond)) =>
+    case j @ Join(left, right, jt: InnerLike, Some(cond), _) =>
       val replacedLeft = replaceWithOrderedJoin(left)
       val replacedRight = replaceWithOrderedJoin(right)
       OrderedJoin(replacedLeft, replacedRight, jt, Some(cond))
-    case p @ Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond))) =>
+    case p @ Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), _)) 
=>
       p.copy(child = replaceWithOrderedJoin(j))
     case _ =>
       plan
@@ -295,7 +339,7 @@ object JoinReorderDP extends PredicateHelper with Logging {
     } else {
       (otherPlan, onePlan)
     }
-    val newJoin = Join(left, right, Inner, joinConds.reduceOption(And))
+    val newJoin = Join(left, right, Inner, joinConds.reduceOption(And), 
JoinHint.NONE)
     val collectedJoinConds = joinConds ++ oneJoinPlan.joinConds ++ 
otherJoinPlan.joinConds
     val remainingConds = conditions -- collectedJoinConds
     val neededAttr = AttributeSet(remainingConds.flatMap(_.references)) ++ 
topOutput
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala
new file mode 100644
index 0000000000000..bbe4eee4b4326
--- /dev/null
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * Removes [[ResolvedHint]] operators from the plan. Each [[HintInfo]] is moved into the
+ * [[JoinHint]] of the [[Join]] it applies to, or dropped if no [[Join]] operator is matched.
+ */
+object EliminateResolvedHint extends Rule[LogicalPlan] {
+  // This is also called at the beginning of the optimization phase, and as a result
+  // uses transformUp rather than resolveOperators.
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    val pulledUp = plan transformUp {
+      case j: Join =>
+        val leftHint = mergeHints(collectHints(j.left))
+        val rightHint = mergeHints(collectHints(j.right))
+        j.copy(hint = JoinHint(leftHint, rightHint))  // replaces any existing join hint
+    }
+    pulledUp.transform {  // strip every ResolvedHint node, matched or not
+      case h: ResolvedHint => h.child
+    }
+  }
+
+  private def mergeHints(hints: Seq[HintInfo]): Option[HintInfo] = {
+    hints.reduceOption((h1, h2) => HintInfo(  // None when `hints` is empty
+      broadcast = h1.broadcast || h2.broadcast))  // broadcast if any input hint is
+  }
+
+  private def collectHints(plan: LogicalPlan): Seq[HintInfo] = {
+    plan match {
+      case h: ResolvedHint => collectHints(h.child) :+ h.hints
+      case u: UnaryNode => collectHints(u.child)
+      // TODO revisit this logic:
+      // except and intersect are semi/anti-joins which won't return more data than
+      // their left argument, so the broadcast hint should be propagated here
+      case i: Intersect => collectHints(i.left)
+      case e: Except => collectHints(e.left)
+      case _ => Seq.empty  // stop at any other node, e.g. a nested Join
+    }
+  }
+}
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 44d5543114902..06f908281dd3c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -115,6 +115,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
     // However, because we also use the analyzer to canonicalized queries (for 
view definition),
     // we do not eliminate subqueries or compute current time in the analyzer.
     Batch("Finish Analysis", Once,
+      EliminateResolvedHint,
       EliminateSubqueryAliases,
       EliminateView,
       ReplaceExpressions,
@@ -192,6 +193,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
    */
   def nonExcludableRules: Seq[String] =
     EliminateDistinct.ruleName ::
+      EliminateResolvedHint.ruleName ::
       EliminateSubqueryAliases.ruleName ::
       EliminateView.ruleName ::
       ReplaceExpressions.ruleName ::
@@ -356,7 +358,7 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] {
       // not allowed to use the same attributes. We use a blacklist to prevent 
us from creating a
       // situation in which this happens; the rule will only remove an alias 
if its child
       // attribute is not on the black list.
-      case Join(left, right, joinType, condition) =>
+      case Join(left, right, joinType, condition, hint) =>
         val newLeft = removeRedundantAliases(left, blacklist ++ 
right.outputSet)
         val newRight = removeRedundantAliases(right, blacklist ++ 
newLeft.outputSet)
         val mapping = AttributeMap(
@@ -365,7 +367,7 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] {
         val newCondition = condition.map(_.transform {
           case a: Attribute => mapping.getOrElse(a, a)
         })
-        Join(newLeft, newRight, joinType, newCondition)
+        Join(newLeft, newRight, joinType, newCondition, hint)
 
       case _ =>
         // Remove redundant aliases in the subtree(s).
@@ -460,7 +462,7 @@ object LimitPushDown extends Rule[LogicalPlan] {
     // on both sides if it is applied multiple times. Therefore:
     //   - If one side is already limited, stack another limit on top if the 
new limit is smaller.
     //     The redundant limit will be collapsed by the CombineLimits rule.
-    case LocalLimit(exp, join @ Join(left, right, joinType, _)) =>
+    case LocalLimit(exp, join @ Join(left, right, joinType, _, _)) =>
       val newJoin = joinType match {
         case RightOuter => join.copy(right = maybePushLocalLimit(exp, right))
         case LeftOuter => join.copy(left = maybePushLocalLimit(exp, left))
@@ -578,7 +580,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
       p.copy(child = g.copy(child = newChild, unrequiredChildIndex = 
unrequiredIndices))
 
     // Eliminate unneeded attributes from right side of a Left Existence Join.
-    case j @ Join(_, right, LeftExistence(_), _) =>
+    case j @ Join(_, right, LeftExistence(_), _, _) =>
       j.copy(right = prunedChild(right, j.references))
 
     // all the columns will be used to compare, so we can't prune them
@@ -792,7 +794,7 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
         filter
       }
 
-    case join @ Join(left, right, joinType, conditionOpt) =>
+    case join @ Join(left, right, joinType, conditionOpt, _) =>
       joinType match {
         // For inner join, we can infer additional filters for both sides. 
LeftSemi is kind of an
         // inner join, it just drops the right side in the final output.
@@ -919,7 +921,6 @@ object RemoveRedundantSorts extends Rule[LogicalPlan] {
   def canEliminateSort(plan: LogicalPlan): Boolean = plan match {
     case p: Project => p.projectList.forall(_.deterministic)
     case f: Filter => f.condition.deterministic
-    case _: ResolvedHint => true
     case _ => false
   }
 }
@@ -1094,7 +1095,6 @@ object PushDownPredicate extends Rule[LogicalPlan] with 
PredicateHelper {
     // Note that some operators (e.g. project, aggregate, union) are being 
handled separately
     // (earlier in this rule).
     case _: AppendColumns => true
-    case _: ResolvedHint => true
     case _: Distinct => true
     case _: Generate => true
     case _: Pivot => true
@@ -1179,7 +1179,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] 
with PredicateHelper {
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     // push the where condition down into join filter
-    case f @ Filter(filterCondition, Join(left, right, joinType, 
joinCondition)) =>
+    case f @ Filter(filterCondition, Join(left, right, joinType, 
joinCondition, hint)) =>
       val (leftFilterConditions, rightFilterConditions, commonFilterCondition) 
=
         split(splitConjunctivePredicates(filterCondition), left, right)
       joinType match {
@@ -1193,7 +1193,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] 
with PredicateHelper {
             commonFilterCondition.partition(canEvaluateWithinJoin)
           val newJoinCond = (newJoinConditions ++ 
joinCondition).reduceLeftOption(And)
 
-          val join = Join(newLeft, newRight, joinType, newJoinCond)
+          val join = Join(newLeft, newRight, joinType, newJoinCond, hint)
           if (others.nonEmpty) {
             Filter(others.reduceLeft(And), join)
           } else {
@@ -1205,7 +1205,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] 
with PredicateHelper {
           val newRight = rightFilterConditions.
             reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
           val newJoinCond = joinCondition
-          val newJoin = Join(newLeft, newRight, RightOuter, newJoinCond)
+          val newJoin = Join(newLeft, newRight, RightOuter, newJoinCond, hint)
 
           (leftFilterConditions ++ commonFilterCondition).
             reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
@@ -1215,7 +1215,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] 
with PredicateHelper {
             reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
           val newRight = right
           val newJoinCond = joinCondition
-          val newJoin = Join(newLeft, newRight, joinType, newJoinCond)
+          val newJoin = Join(newLeft, newRight, joinType, newJoinCond, hint)
 
           (rightFilterConditions ++ commonFilterCondition).
             reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
@@ -1225,7 +1225,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] 
with PredicateHelper {
       }
 
     // push down the join filter into sub query scanning if applicable
-    case j @ Join(left, right, joinType, joinCondition) =>
+    case j @ Join(left, right, joinType, joinCondition, hint) =>
       val (leftJoinConditions, rightJoinConditions, commonJoinCondition) =
         split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), 
left, right)
 
@@ -1238,7 +1238,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] 
with PredicateHelper {
             reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
           val newJoinCond = commonJoinCondition.reduceLeftOption(And)
 
-          Join(newLeft, newRight, joinType, newJoinCond)
+          Join(newLeft, newRight, joinType, newJoinCond, hint)
         case RightOuter =>
           // push down the left side only join filter for left side sub query
           val newLeft = leftJoinConditions.
@@ -1246,7 +1246,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] 
with PredicateHelper {
           val newRight = right
           val newJoinCond = (rightJoinConditions ++ 
commonJoinCondition).reduceLeftOption(And)
 
-          Join(newLeft, newRight, RightOuter, newJoinCond)
+          Join(newLeft, newRight, RightOuter, newJoinCond, hint)
         case LeftOuter | LeftAnti | ExistenceJoin(_) =>
           // push down the right side only join filter for right sub query
           val newLeft = left
@@ -1254,7 +1254,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] 
with PredicateHelper {
             reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
           val newJoinCond = (leftJoinConditions ++ 
commonJoinCondition).reduceLeftOption(And)
 
-          Join(newLeft, newRight, joinType, newJoinCond)
+          Join(newLeft, newRight, joinType, newJoinCond, hint)
         case FullOuter => j
         case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
         case UsingJoin(_, _) => sys.error("Untransformed Using join node")
@@ -1310,7 +1310,7 @@ object CheckCartesianProducts extends Rule[LogicalPlan] 
with PredicateHelper {
     if (SQLConf.get.crossJoinEnabled) {
       plan
     } else plan transform {
-      case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, _)
+      case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, 
_, _)
         if isCartesianProduct(j) =>
           throw new AnalysisException(
             s"""Detected implicit cartesian product for ${j.joinType.sql} join 
between logical plans
@@ -1449,7 +1449,7 @@ object ReplaceIntersectWithSemiJoin extends 
Rule[LogicalPlan] {
     case Intersect(left, right, false) =>
       assert(left.output.size == right.output.size)
       val joinCond = left.output.zip(right.output).map { case (l, r) => 
EqualNullSafe(l, r) }
-      Distinct(Join(left, right, LeftSemi, joinCond.reduceLeftOption(And)))
+      Distinct(Join(left, right, LeftSemi, joinCond.reduceLeftOption(And), 
JoinHint.NONE))
   }
 }
 
@@ -1470,7 +1470,7 @@ object ReplaceExceptWithAntiJoin extends 
Rule[LogicalPlan] {
     case Except(left, right, false) =>
       assert(left.output.size == right.output.size)
       val joinCond = left.output.zip(right.output).map { case (l, r) => 
EqualNullSafe(l, r) }
-      Distinct(Join(left, right, LeftAnti, joinCond.reduceLeftOption(And)))
+      Distinct(Join(left, right, LeftAnti, joinCond.reduceLeftOption(And), 
JoinHint.NONE))
   }
 }
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
index c3fdb924243df..b19e13870aa65 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
@@ -56,7 +56,7 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with 
PredicateHelper wit
     // Joins on empty LocalRelations generated from streaming sources are not 
eliminated
     // as stateful streaming joins need to perform other state management 
operations other than
     // just processing the input data.
-    case p @ Join(_, _, joinType, _)
+    case p @ Join(_, _, joinType, _, _)
         if !p.children.exists(_.isStreaming) =>
       val isLeftEmpty = isEmptyLocalRelation(p.left)
       val isRightEmpty = isEmptyLocalRelation(p.right)
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala
index 72a60f692ac78..689915a985343 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala
@@ -52,7 +52,7 @@ object ReplaceNullWithFalseInPredicate extends 
Rule[LogicalPlan] {
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond))
-    case j @ Join(_, _, _, Some(cond)) => j.copy(condition = 
Some(replaceNullWithFalse(cond)))
+    case j @ Join(_, _, _, Some(cond), _) => j.copy(condition = 
Some(replaceNullWithFalse(cond)))
     case p: LogicalPlan => p transformExpressions {
       case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred))
       case cw @ CaseWhen(branches, _) =>
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 468a950fb1087..39709529c00d3 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -600,7 +600,7 @@ object FoldablePropagation extends Rule[LogicalPlan] {
         // propagating the foldable expressions.
         // TODO(cloud-fan): It seems more reasonable to use new attributes as 
the output attributes
         // of outer join.
-        case j @ Join(left, right, joinType, _) if foldableMap.nonEmpty =>
+        case j @ Join(left, right, joinType, _, _) if foldableMap.nonEmpty =>
           val newJoin = j.transformExpressions(replaceFoldable)
           val missDerivedAttrsSet: AttributeSet = AttributeSet(joinType match {
             case _: InnerLike | LeftExistence(_) => Nil
@@ -648,7 +648,6 @@ object FoldablePropagation extends Rule[LogicalPlan] {
     case _: Distinct => true
     case _: AppendColumns => true
     case _: AppendColumnsWithObject => true
-    case _: ResolvedHint => true
     case _: RepartitionByExpression => true
     case _: Repartition => true
     case _: Sort => true
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index 0b6471289a471..82aefca8a1af6 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -43,10 +43,13 @@ object ReorderJoin extends Rule[LogicalPlan] with 
PredicateHelper {
    *
    * @param input a list of LogicalPlans to inner join and the type of inner 
join.
    * @param conditions a list of condition for join.
+   * @param hintMap a map of relation output attribute sets to their 
corresponding hints.
    */
   @tailrec
-  final def createOrderedJoin(input: Seq[(LogicalPlan, InnerLike)], 
conditions: Seq[Expression])
-    : LogicalPlan = {
+  final def createOrderedJoin(
+      input: Seq[(LogicalPlan, InnerLike)],
+      conditions: Seq[Expression],
+      hintMap: Map[AttributeSet, HintInfo]): LogicalPlan = {
     assert(input.size >= 2)
     if (input.size == 2) {
       val (joinConditions, others) = 
conditions.partition(canEvaluateWithinJoin)
@@ -55,7 +58,8 @@ object ReorderJoin extends Rule[LogicalPlan] with 
PredicateHelper {
         case (Inner, Inner) => Inner
         case (_, _) => Cross
       }
-      val join = Join(left, right, innerJoinType, 
joinConditions.reduceLeftOption(And))
+      val join = Join(left, right, innerJoinType, 
joinConditions.reduceLeftOption(And),
+        JoinHint(hintMap.get(left.outputSet), hintMap.get(right.outputSet)))
       if (others.nonEmpty) {
         Filter(others.reduceLeft(And), join)
       } else {
@@ -78,26 +82,27 @@ object ReorderJoin extends Rule[LogicalPlan] with 
PredicateHelper {
       val joinedRefs = left.outputSet ++ right.outputSet
       val (joinConditions, others) = conditions.partition(
         e => e.references.subsetOf(joinedRefs) && canEvaluateWithinJoin(e))
-      val joined = Join(left, right, innerJoinType, 
joinConditions.reduceLeftOption(And))
+      val joined = Join(left, right, innerJoinType, 
joinConditions.reduceLeftOption(And),
+        JoinHint(hintMap.get(left.outputSet), hintMap.get(right.outputSet)))
 
       // should not have reference to same logical plan
-      createOrderedJoin(Seq((joined, Inner)) ++ rest.filterNot(_._1 eq right), 
others)
+      createOrderedJoin(Seq((joined, Inner)) ++ rest.filterNot(_._1 eq right), 
others, hintMap)
     }
   }
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case p @ ExtractFiltersAndInnerJoins(input, conditions)
+    case p @ ExtractFiltersAndInnerJoins(input, conditions, hintMap)
         if input.size > 2 && conditions.nonEmpty =>
       val reordered = if (SQLConf.get.starSchemaDetection && 
!SQLConf.get.cboEnabled) {
         val starJoinPlan = StarSchemaDetection.reorderStarJoins(input, 
conditions)
         if (starJoinPlan.nonEmpty) {
           val rest = input.filterNot(starJoinPlan.contains(_))
-          createOrderedJoin(starJoinPlan ++ rest, conditions)
+          createOrderedJoin(starJoinPlan ++ rest, conditions, hintMap)
         } else {
-          createOrderedJoin(input, conditions)
+          createOrderedJoin(input, conditions, hintMap)
         }
       } else {
-        createOrderedJoin(input, conditions)
+        createOrderedJoin(input, conditions, hintMap)
       }
 
       if (p.sameOutput(reordered)) {
@@ -156,7 +161,7 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with 
PredicateHelper {
   }
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case f @ Filter(condition, j @ Join(_, _, RightOuter | LeftOuter | 
FullOuter, _)) =>
+    case f @ Filter(condition, j @ Join(_, _, RightOuter | LeftOuter | 
FullOuter, _, _)) =>
       val newJoinType = buildNewJoinType(f, j)
       if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType 
= newJoinType))
   }
@@ -176,7 +181,7 @@ object PullOutPythonUDFInJoinCondition extends 
Rule[LogicalPlan] with PredicateH
   }
 
   override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
-    case j @ Join(_, _, joinType, Some(cond)) if hasUnevaluablePythonUDF(cond, 
j) =>
+    case j @ Join(_, _, joinType, Some(cond), _) if 
hasUnevaluablePythonUDF(cond, j) =>
       if (!joinType.isInstanceOf[InnerLike] && joinType != LeftSemi) {
         // The current strategy only support InnerLike and LeftSemi join 
because for other type,
         // it breaks SQL semantic if we run the join condition as a filter 
after join. If we pass
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 34840c6c977a6..e78ed1c3c5d94 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -51,7 +51,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] 
with PredicateHelper {
       condition: Option[Expression]): Join = {
     // Deduplicate conflicting attributes if any.
     val dedupSubplan = dedupSubqueryOnSelfJoin(outerPlan, subplan, None, 
condition)
-    Join(outerPlan, dedupSubplan, joinType, condition)
+    Join(outerPlan, dedupSubplan, joinType, condition, JoinHint.NONE)
   }
 
   private def dedupSubqueryOnSelfJoin(
@@ -116,7 +116,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] 
with PredicateHelper {
           val newSub = dedupSubqueryOnSelfJoin(p, sub, Some(values))
           val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
           val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ 
conditions, p)
-          Join(outerPlan, newSub, LeftSemi, joinCond)
+          Join(outerPlan, newSub, LeftSemi, joinCond, JoinHint.NONE)
         case (p, Not(InSubquery(values, ListQuery(sub, conditions, _, _)))) =>
           // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
           // Construct the condition. A NULL in one of the conditions is 
regarded as a positive
@@ -142,7 +142,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] 
with PredicateHelper {
           // will have the final conditions in the LEFT ANTI as
           // (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) AND B.B3 > 
1
           val finalJoinCond = (nullAwareJoinConds ++ 
conditions).reduceLeft(And)
-          Join(outerPlan, newSub, LeftAnti, Option(finalJoinCond))
+          Join(outerPlan, newSub, LeftAnti, Option(finalJoinCond), 
JoinHint.NONE)
         case (p, predicate) =>
           val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p)
           Project(p.output, Filter(newCond.get, inputPlan))
@@ -172,7 +172,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] 
with PredicateHelper {
           val newSub = dedupSubqueryOnSelfJoin(newPlan, sub, Some(values))
           val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
           val newConditions = (inConditions ++ 
conditions).reduceLeftOption(And)
-          newPlan = Join(newPlan, newSub, ExistenceJoin(exists), newConditions)
+          newPlan = Join(newPlan, newSub, ExistenceJoin(exists), 
newConditions, JoinHint.NONE)
           exists
       }
     }
@@ -450,7 +450,7 @@ object RewriteCorrelatedScalarSubquery extends 
Rule[LogicalPlan] {
           // CASE 1: Subquery guaranteed not to have the COUNT bug
           Project(
             currentChild.output :+ origOutput,
-            Join(currentChild, query, LeftOuter, conditions.reduceOption(And)))
+            Join(currentChild, query, LeftOuter, conditions.reduceOption(And), 
JoinHint.NONE))
         } else {
           // Subquery might have the COUNT bug. Add appropriate corrections.
           val (topPart, havingNode, aggNode) = splitSubquery(query)
@@ -477,7 +477,7 @@ object RewriteCorrelatedScalarSubquery extends 
Rule[LogicalPlan] {
                     aggValRef), origOutput.name)(exprId = origOutput.exprId),
               Join(currentChild,
                 Project(query.output :+ alwaysTrueExpr, query),
-                LeftOuter, conditions.reduceOption(And)))
+                LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
 
           } else {
             // CASE 3: Subquery with HAVING clause. Pull the HAVING clause 
above the join.
@@ -507,7 +507,7 @@ object RewriteCorrelatedScalarSubquery extends 
Rule[LogicalPlan] {
               currentChild.output :+ caseExpr,
               Join(currentChild,
                 Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot),
-                LeftOuter, conditions.reduceOption(And)))
+                LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
 
           }
         }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 8959f78b656d2..a27c6d3c3671c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -515,7 +515,7 @@ class AstBuilder(conf: SQLConf) extends 
SqlBaseBaseVisitor[AnyRef] with Logging
   override def visitFromClause(ctx: FromClauseContext): LogicalPlan = 
withOrigin(ctx) {
     val from = ctx.relation.asScala.foldLeft(null: LogicalPlan) { (left, 
relation) =>
       val right = plan(relation.relationPrimary)
-      val join = right.optionalMap(left)(Join(_, _, Inner, None))
+      val join = right.optionalMap(left)(Join(_, _, Inner, None, 
JoinHint.NONE))
       withJoinRelations(join, relation)
     }
     if (ctx.pivotClause() != null) {
@@ -727,7 +727,7 @@ class AstBuilder(conf: SQLConf) extends 
SqlBaseBaseVisitor[AnyRef] with Logging
           case None =>
             (baseJoinType, None)
         }
-        Join(left, plan(join.right), joinType, condition)
+        Join(left, plan(join.right), joinType, condition, JoinHint.NONE)
       }
     }
   }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 84be677e438a6..dfc3b2d22129d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.planning
 
+import scala.collection.mutable
+
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
@@ -98,12 +100,13 @@ object PhysicalOperation extends PredicateHelper {
  * value).
  */
 object ExtractEquiJoinKeys extends Logging with PredicateHelper {
-  /** (joinType, leftKeys, rightKeys, condition, leftChild, rightChild) */
+  /** (joinType, leftKeys, rightKeys, condition, leftChild, rightChild, 
joinHint) */
   type ReturnType =
-    (JoinType, Seq[Expression], Seq[Expression], Option[Expression], 
LogicalPlan, LogicalPlan)
+    (JoinType, Seq[Expression], Seq[Expression],
+      Option[Expression], LogicalPlan, LogicalPlan, JoinHint)
 
   def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
-    case join @ Join(left, right, joinType, condition) =>
+    case join @ Join(left, right, joinType, condition, hint) =>
       logDebug(s"Considering join on: $condition")
       // Find equi-join predicates that can be evaluated before the join, and 
thus can be used
       // as join keys.
@@ -133,7 +136,7 @@ object ExtractEquiJoinKeys extends Logging with 
PredicateHelper {
       if (joinKeys.nonEmpty) {
         val (leftKeys, rightKeys) = joinKeys.unzip
         logDebug(s"leftKeys:$leftKeys | rightKeys:$rightKeys")
-        Some((joinType, leftKeys, rightKeys, 
otherPredicates.reduceOption(And), left, right))
+        Some((joinType, leftKeys, rightKeys, 
otherPredicates.reduceOption(And), left, right, hint))
       } else {
         None
       }
@@ -164,25 +167,35 @@ object ExtractFiltersAndInnerJoins extends 
PredicateHelper {
    * was involved in an explicit cross join. Also returns the entire list of 
join conditions for
    * the left-deep tree.
    */
-  def flattenJoin(plan: LogicalPlan, parentJoinType: InnerLike = Inner)
+  def flattenJoin(
+      plan: LogicalPlan,
+      hintMap: mutable.HashMap[AttributeSet, HintInfo],
+      parentJoinType: InnerLike = Inner)
       : (Seq[(LogicalPlan, InnerLike)], Seq[Expression]) = plan match {
-    case Join(left, right, joinType: InnerLike, cond) =>
-      val (plans, conditions) = flattenJoin(left, joinType)
+    case Join(left, right, joinType: InnerLike, cond, hint) =>
+      val (plans, conditions) = flattenJoin(left, hintMap, joinType)
+      hint.leftHint.map(hintMap.put(left.outputSet, _))
+      hint.rightHint.map(hintMap.put(right.outputSet, _))
       (plans ++ Seq((right, joinType)), conditions ++
         cond.toSeq.flatMap(splitConjunctivePredicates))
-    case Filter(filterCondition, j @ Join(left, right, _: InnerLike, 
joinCondition)) =>
-      val (plans, conditions) = flattenJoin(j)
+    case Filter(filterCondition, j @ Join(_, _, _: InnerLike, _, _)) =>
+      val (plans, conditions) = flattenJoin(j, hintMap)
       (plans, conditions ++ splitConjunctivePredicates(filterCondition))
 
     case _ => (Seq((plan, parentJoinType)), Seq.empty)
   }
 
-  def unapply(plan: LogicalPlan): Option[(Seq[(LogicalPlan, InnerLike)], 
Seq[Expression])]
+  def unapply(plan: LogicalPlan)
+      : Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression], 
Map[AttributeSet, HintInfo])]
       = plan match {
-    case f @ Filter(filterCondition, j @ Join(_, _, joinType: InnerLike, _)) =>
-      Some(flattenJoin(f))
-    case j @ Join(_, _, joinType, _) =>
-      Some(flattenJoin(j))
+    case f @ Filter(filterCondition, j @ Join(_, _, joinType: InnerLike, _, 
_)) =>
+      val hintMap = new mutable.HashMap[AttributeSet, HintInfo]
+      val flattened = flattenJoin(f, hintMap)
+      Some((flattened._1, flattened._2, hintMap.toMap))
+    case j @ Join(_, _, joinType, _, _) =>
+      val hintMap = new mutable.HashMap[AttributeSet, HintInfo]
+      val flattened = flattenJoin(j, hintMap)
+      Some((flattened._1, flattened._2, hintMap.toMap))
     case _ => None
   }
 }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala
index 2c248d74869ce..18baced8f3d61 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala
@@ -37,7 +37,6 @@ trait LogicalPlanVisitor[T] {
     case p: Project => visitProject(p)
     case p: Repartition => visitRepartition(p)
     case p: RepartitionByExpression => visitRepartitionByExpr(p)
-    case p: ResolvedHint => visitHint(p)
     case p: Sample => visitSample(p)
     case p: ScriptTransformation => visitScriptTransform(p)
     case p: Union => visitUnion(p)
@@ -61,8 +60,6 @@ trait LogicalPlanVisitor[T] {
 
   def visitGlobalLimit(p: GlobalLimit): T
 
-  def visitHint(p: ResolvedHint): T
-
   def visitIntersect(p: Intersect): T
 
   def visitJoin(p: Join): T
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
index b3a48860aa63b..5a388117a6c0a 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
@@ -52,13 +52,11 @@ import org.apache.spark.util.Utils
  *                    defaults to the product of children's `sizeInBytes`.
  * @param rowCount Estimated number of rows.
  * @param attributeStats Statistics for Attributes.
- * @param hints Query hints.
  */
 case class Statistics(
     sizeInBytes: BigInt,
     rowCount: Option[BigInt] = None,
-    attributeStats: AttributeMap[ColumnStat] = AttributeMap(Nil),
-    hints: HintInfo = HintInfo()) {
+    attributeStats: AttributeMap[ColumnStat] = AttributeMap(Nil)) {
 
   override def toString: String = "Statistics(" + simpleString + ")"
 
@@ -70,8 +68,7 @@ case class Statistics(
         s"rowCount=${BigDecimal(rowCount.get, new MathContext(3, 
RoundingMode.HALF_UP)).toString()}"
       } else {
         ""
-      },
-      s"hints=$hints"
+      }
     ).filter(_.nonEmpty).mkString(", ")
   }
 }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index d8b3a4af4f7bf..639d68f4ecd76 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -288,7 +288,8 @@ case class Join(
     left: LogicalPlan,
     right: LogicalPlan,
     joinType: JoinType,
-    condition: Option[Expression])
+    condition: Option[Expression],
+    hint: JoinHint)
   extends BinaryNode with PredicateHelper {
 
   override def output: Seq[Attribute] = {
@@ -350,6 +351,17 @@ case class Join(
     case UsingJoin(_, _) => false
     case _ => resolvedExceptNatural
   }
+
+  // Ignore hint for canonicalization
+  protected override def doCanonicalize(): LogicalPlan =
+    super.doCanonicalize().asInstanceOf[Join].copy(hint = JoinHint.NONE)
+
+  // Do not include an empty join hint in string description
+  protected override def stringArgs: Iterator[Any] = super.stringArgs.filter { 
e =>
+    (!e.isInstanceOf[JoinHint]
+      || e.asInstanceOf[JoinHint].leftHint.isDefined
+      || e.asInstanceOf[JoinHint].rightHint.isDefined)
+  }
 }
 
 /**
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
index cbb626590d1d7..b2ba725e9d44f 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
@@ -35,6 +35,7 @@ case class UnresolvedHint(name: String, parameters: Seq[Any], 
child: LogicalPlan
 
 /**
  * A resolved hint node. The analyzer should convert all [[UnresolvedHint]] 
into [[ResolvedHint]].
+ * This node will be eliminated before optimization starts.
  */
 case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo())
   extends UnaryNode {
@@ -44,11 +45,31 @@ case class ResolvedHint(child: LogicalPlan, hints: HintInfo 
= HintInfo())
   override def doCanonicalize(): LogicalPlan = child.canonicalized
 }
 
+/**
+ * Hint that is associated with a [[Join]] node, with [[HintInfo]] on its left 
child and on its
+ * right child respectively.
+ */
+case class JoinHint(leftHint: Option[HintInfo], rightHint: Option[HintInfo]) {
 
-case class HintInfo(broadcast: Boolean = false) {
+  override def toString: String = {
+    Seq(
+      leftHint.map("leftHint=" + _),
+      rightHint.map("rightHint=" + _))
+      .filter(_.isDefined).map(_.get).mkString(", ")
+  }
+}
 
-  /** Must be called when computing stats for a join operator to reset hints. 
*/
-  def resetForJoin(): HintInfo = copy(broadcast = false)
+object JoinHint {
+  val NONE = JoinHint(None, None)
+}
+
+/**
+ * The hint attributes to be applied on a specific node.
+ *
+ * @param broadcast If set to true, it indicates that the broadcast hash join 
is the preferred join
+ *                  strategy and the node with this hint is preferred to be 
the build side.
+ */
+case class HintInfo(broadcast: Boolean = false) {
 
   override def toString: String = {
     val hints = scala.collection.mutable.ArrayBuffer.empty[String]
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
index 111c594a53e52..eb56ab43ea9d5 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
@@ -56,8 +56,7 @@ object AggregateEstimation {
       Some(Statistics(
         sizeInBytes = getOutputSize(agg.output, outputRows, outputAttrStats),
         rowCount = Some(outputRows),
-        attributeStats = outputAttrStats,
-        hints = childStats.hints))
+        attributeStats = outputAttrStats))
     } else {
       None
     }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala
index b6c16079d1984..b8c652dc8f12e 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala
@@ -47,8 +47,6 @@ object BasicStatsPlanVisitor extends 
LogicalPlanVisitor[Statistics] {
 
   override def visitGlobalLimit(p: GlobalLimit): Statistics = fallback(p)
 
-  override def visitHint(p: ResolvedHint): Statistics = fallback(p)
-
   override def visitIntersect(p: Intersect): Statistics = fallback(p)
 
   override def visitJoin(p: Join): Statistics = {
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
index 2543e38a92c0a..19a0d1279cc32 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
@@ -56,7 +56,7 @@ case class JoinEstimation(join: Join) extends Logging {
     case _ if !rowCountsExist(join.left, join.right) =>
       None
 
-    case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _) =>
+    case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _, _) =>
       // 1. Compute join selectivity
       val joinKeyPairs = extractJoinKeysWithColStats(leftKeys, rightKeys)
       val (numInnerJoinedRows, keyStatsAfterJoin) = 
computeCardinalityAndStats(joinKeyPairs)
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala
index ee43f9126386b..da36db7ae1f5f 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala
@@ -44,7 +44,7 @@ object SizeInBytesOnlyStatsPlanVisitor extends 
LogicalPlanVisitor[Statistics] {
     }
 
     // Don't propagate rowCount and attributeStats, since they are not 
estimated here.
-    Statistics(sizeInBytes = sizeInBytes, hints = p.child.stats.hints)
+    Statistics(sizeInBytes = sizeInBytes)
   }
 
   /**
@@ -60,8 +60,7 @@ object SizeInBytesOnlyStatsPlanVisitor extends 
LogicalPlanVisitor[Statistics] {
     if (p.groupingExpressions.isEmpty) {
       Statistics(
         sizeInBytes = EstimationUtils.getOutputSize(p.output, outputRowCount = 
1),
-        rowCount = Some(1),
-        hints = p.child.stats.hints)
+        rowCount = Some(1))
     } else {
       visitUnaryNode(p)
     }
@@ -87,19 +86,15 @@ object SizeInBytesOnlyStatsPlanVisitor extends 
LogicalPlanVisitor[Statistics] {
     // Don't propagate column stats, because we don't know the distribution 
after limit
     Statistics(
       sizeInBytes = EstimationUtils.getOutputSize(p.output, rowCount, 
childStats.attributeStats),
-      rowCount = Some(rowCount),
-      hints = childStats.hints)
+      rowCount = Some(rowCount))
   }
 
-  override def visitHint(p: ResolvedHint): Statistics = 
p.child.stats.copy(hints = p.hints)
-
   override def visitIntersect(p: Intersect): Statistics = {
     val leftSize = p.left.stats.sizeInBytes
     val rightSize = p.right.stats.sizeInBytes
     val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize
     Statistics(
-      sizeInBytes = sizeInBytes,
-      hints = p.left.stats.hints.resetForJoin())
+      sizeInBytes = sizeInBytes)
   }
 
   override def visitJoin(p: Join): Statistics = {
@@ -108,10 +103,7 @@ object SizeInBytesOnlyStatsPlanVisitor extends 
LogicalPlanVisitor[Statistics] {
         // LeftSemi and LeftAnti won't ever be bigger than left
         p.left.stats
       case _ =>
-        // Make sure we don't propagate isBroadcastable in other joins, because
-        // they could explode the size.
-        val stats = default(p)
-        stats.copy(hints = stats.hints.resetForJoin())
+        default(p)
     }
   }
 
@@ -121,7 +113,7 @@ object SizeInBytesOnlyStatsPlanVisitor extends 
LogicalPlanVisitor[Statistics] {
     if (limit == 0) {
       // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be 
zero
       // (product of children).
-      Statistics(sizeInBytes = 1, rowCount = Some(0), hints = childStats.hints)
+      Statistics(sizeInBytes = 1, rowCount = Some(0))
     } else {
       // The output row count of LocalLimit should be the sum of row counts 
from each partition.
       // However, since the number of partitions is not available here, we 
just use statistics of
@@ -147,7 +139,7 @@ object SizeInBytesOnlyStatsPlanVisitor extends 
LogicalPlanVisitor[Statistics] {
     }
     val sampleRows = p.child.stats.rowCount.map(c => 
EstimationUtils.ceil(BigDecimal(c) * ratio))
     // Don't propagate column stats, because we don't know the distribution 
after a sample operation
-    Statistics(sizeInBytes, sampleRows, hints = p.child.stats.hints)
+    Statistics(sizeInBytes, sampleRows)
   }
 
   override def visitScriptTransform(p: ScriptTransformation): Statistics = 
default(p)
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 117e96175e92a..129ce3b1105ee 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -443,7 +443,7 @@ class AnalysisErrorSuite extends AnalysisTest {
   }
 
   test("error test for self-join") {
-    val join = Join(testRelation, testRelation, Cross, None)
+    val join = Join(testRelation, testRelation, Cross, None, JoinHint.NONE)
     val error = intercept[AnalysisException] {
       SimpleAnalyzer.checkAnalysis(join)
     }
@@ -565,7 +565,8 @@ class AnalysisErrorSuite extends AnalysisTest {
           LocalRelation(b),
           Filter(EqualTo(UnresolvedAttribute("a"), c), LocalRelation(c)),
           LeftOuter,
-          Option(EqualTo(b, c)))),
+          Option(EqualTo(b, c)),
+          JoinHint.NONE)),
       LocalRelation(a))
     assertAnalysisError(plan1, "Accessing outer query column is not allowed 
in" :: Nil)
 
@@ -575,7 +576,8 @@ class AnalysisErrorSuite extends AnalysisTest {
           Filter(EqualTo(UnresolvedAttribute("a"), c), LocalRelation(c)),
           LocalRelation(b),
           RightOuter,
-          Option(EqualTo(b, c)))),
+          Option(EqualTo(b, c)),
+          JoinHint.NONE)),
       LocalRelation(a))
     assertAnalysisError(plan2, "Accessing outer query column is not allowed 
in" :: Nil)
 
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index da3ae72c3682a..982948483fa1c 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -397,7 +397,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
         Join(
           Project(Seq($"x.key"), SubqueryAlias("x", input)),
           Project(Seq($"y.key"), SubqueryAlias("y", input)),
-          Cross, None))
+          Cross, None, JoinHint.NONE))
 
     assertAnalysisSuccess(query)
   }
@@ -578,7 +578,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
       Seq(UnresolvedAttribute("a")), pythonUdf, output, project)
     val left = SubqueryAlias("temp0", flatMapGroupsInPandas)
     val right = SubqueryAlias("temp1", flatMapGroupsInPandas)
-    val join = Join(left, right, Inner, None)
+    val join = Join(left, right, Inner, None, JoinHint.NONE)
     assertAnalysisSuccess(
       Project(Seq(UnresolvedAttribute("temp0.a"), 
UnresolvedAttribute("temp1.a")), join))
   }
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
index bd66ee5355f45..563e8adf87edc 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
@@ -60,7 +60,7 @@ class ResolveHintsSuite extends AnalysisTest {
     checkAnalysis(
       UnresolvedHint("MAPJOIN", Seq("table", "table2"), 
table("table").join(table("table2"))),
       Join(ResolvedHint(testRelation, HintInfo(broadcast = true)),
-        ResolvedHint(testRelation2, HintInfo(broadcast = true)), Inner, None),
+        ResolvedHint(testRelation2, HintInfo(broadcast = true)), Inner, None, 
JoinHint.NONE),
       caseSensitive = false)
   }
 
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index 57195d5fda7c5..0cd6e092e2036 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -353,15 +353,15 @@ class ColumnPruningSuite extends PlanTest {
       Project(Seq($"x.key", $"y.key"),
         Join(
           SubqueryAlias("x", input),
-          ResolvedHint(SubqueryAlias("y", input)), Inner, None)).analyze
+          SubqueryAlias("y", input), Inner, None, JoinHint.NONE)).analyze
 
     val optimized = Optimize.execute(query)
 
     val expected =
       Join(
         Project(Seq($"x.key"), SubqueryAlias("x", input)),
-        ResolvedHint(Project(Seq($"y.key"), SubqueryAlias("y", input))),
-        Inner, None).analyze
+        Project(Seq($"y.key"), SubqueryAlias("y", input)),
+        Inner, None, JoinHint.NONE).analyze
 
     comparePlans(optimized, expected)
   }
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 82a10254d846d..cf4e9fcea2c6d 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
@@ -822,19 +821,6 @@ class FilterPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("broadcast hint") {
-    val originalQuery = ResolvedHint(testRelation)
-      .where('a === 2L && 'b + Rand(10).as("rnd") === 3)
-
-    val optimized = Optimize.execute(originalQuery.analyze)
-
-    val correctAnswer = ResolvedHint(testRelation.where('a === 2L))
-      .where('b + Rand(10).as("rnd") === 3)
-      .analyze
-
-    comparePlans(optimized, correctAnswer)
-  }
-
   test("union") {
     val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
 
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
index 6fe5e619d03ad..9093d7fecb0f7 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
@@ -65,7 +65,8 @@ class JoinOptimizationSuite extends PlanTest {
 
     def testExtractCheckCross
         (plan: LogicalPlan, expected: Option[(Seq[(LogicalPlan, InnerLike)], 
Seq[Expression])]) {
-      assert(ExtractFiltersAndInnerJoins.unapply(plan) === expected)
+      assert(
+        ExtractFiltersAndInnerJoins.unapply(plan) === expected.map(e => (e._1, 
e._2, Map.empty)))
     }
 
     testExtract(x, None)
@@ -124,29 +125,4 @@ class JoinOptimizationSuite extends PlanTest {
       comparePlans(optimized, queryAnswerPair._2.analyze)
     }
   }
-
-  test("broadcasthint sets relation statistics to smallest value") {
-    val input = LocalRelation('key.int, 'value.string)
-
-    val query =
-      Project(Seq($"x.key", $"y.key"),
-        Join(
-          SubqueryAlias("x", input),
-          ResolvedHint(SubqueryAlias("y", input)), Cross, None)).analyze
-
-    val optimized = Optimize.execute(query)
-
-    val expected =
-      Join(
-        Project(Seq($"x.key"), SubqueryAlias("x", input)),
-        ResolvedHint(Project(Seq($"y.key"), SubqueryAlias("y", input))),
-        Cross, None).analyze
-
-    comparePlans(optimized, expected)
-
-    val broadcastChildren = optimized.collect {
-      case Join(_, r, _, _) if r.stats.sizeInBytes == 1 => r
-    }
-    assert(broadcastChildren.size == 1)
-  }
 }
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
index c94a8b9e318f6..0dee846205868 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
@@ -31,6 +31,8 @@ class JoinReorderSuite extends PlanTest with 
StatsEstimationTestBase {
 
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches =
+      Batch("Resolve Hints", Once,
+        EliminateResolvedHint) ::
       Batch("Operator Optimizations", FixedPoint(100),
         CombineFilters,
         PushDownPredicate,
@@ -42,6 +44,12 @@ class JoinReorderSuite extends PlanTest with 
StatsEstimationTestBase {
         CostBasedJoinReorder) :: Nil
   }
 
+  object ResolveHints extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("Resolve Hints", Once,
+        EliminateResolvedHint) :: Nil
+  }
+
   var originalConfCBOEnabled = false
   var originalConfJoinReorderEnabled = false
 
@@ -284,12 +292,85 @@ class JoinReorderSuite extends PlanTest with 
StatsEstimationTestBase {
     assertEqualPlans(originalPlan, bestPlan)
   }
 
+  test("hints preservation") {
+    // Apply hints if we find an equivalent node in the new plan, otherwise 
discard them.
+    val originalPlan =
+      
t1.join(t2.hint("broadcast")).hint("broadcast").join(t4.join(t3).hint("broadcast"))
+        .where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
+          (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
+          (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
+
+    val bestPlan =
+    t1.join(t2.hint("broadcast"), Inner, Some(nameToAttr("t1.k-1-2") === 
nameToAttr("t2.k-1-5")))
+      .hint("broadcast")
+      .join(
+        t4.join(t3, Inner, Some(nameToAttr("t4.v-1-10") === 
nameToAttr("t3.v-1-100")))
+          .hint("broadcast"),
+        Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
+
+    assertEqualPlans(originalPlan, bestPlan)
+
+    val originalPlan2 =
+      
t1.join(t2).hint("broadcast").join(t3).hint("broadcast").join(t4.hint("broadcast"))
+        .where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
+          (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
+          (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
+
+    val bestPlan2 =
+      t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === 
nameToAttr("t2.k-1-5")))
+        .hint("broadcast")
+        .join(
+          t4.hint("broadcast")
+            .join(t3, Inner, Some(nameToAttr("t4.v-1-10") === 
nameToAttr("t3.v-1-100"))),
+          Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
+        .select(outputsOf(t1, t2, t3, t4): _*)
+
+    assertEqualPlans(originalPlan2, bestPlan2)
+
+    val originalPlan3 =
+      t1.join(t4).hint("broadcast")
+        .join(t2.hint("broadcast")).hint("broadcast")
+        .join(t3.hint("broadcast"))
+        .where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
+        (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
+        (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
+
+    val bestPlan3 =
+      t1.join(t2.hint("broadcast"), Inner, Some(nameToAttr("t1.k-1-2") === 
nameToAttr("t2.k-1-5")))
+        .join(
+          t4.join(t3.hint("broadcast"),
+            Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))),
+          Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
+        .select(outputsOf(t1, t4, t2, t3): _*)
+
+    assertEqualPlans(originalPlan3, bestPlan3)
+
+    val originalPlan4 =
+      t2.hint("broadcast")
+        .join(t4).hint("broadcast")
+        .join(t3.hint("broadcast")).hint("broadcast")
+        .join(t1)
+        .where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
+          (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
+          (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
+
+    val bestPlan4 =
+      t1.join(t2.hint("broadcast"), Inner, Some(nameToAttr("t1.k-1-2") === 
nameToAttr("t2.k-1-5")))
+        .join(
+          t4.join(t3.hint("broadcast"),
+            Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))),
+          Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
+        .select(outputsOf(t2, t4, t3, t1): _*)
+
+    assertEqualPlans(originalPlan4, bestPlan4)
+  }
+
   private def assertEqualPlans(
       originalPlan: LogicalPlan,
       groundTruthBestPlan: LogicalPlan): Unit = {
     val analyzed = originalPlan.analyze
     val optimized = Optimize.execute(analyzed)
-    val expected = groundTruthBestPlan.analyze
+    val expected = ResolveHints.execute(groundTruthBestPlan.analyze)
 
     assert(analyzed.sameOutput(expected)) // if this fails, the expected plan 
itself is incorrect
     assert(analyzed.sameOutput(optimized))
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
index c8e15c7da763e..6d1af12e68b23 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
@@ -48,7 +48,7 @@ class ReplaceOperatorSuite extends PlanTest {
 
     val correctAnswer =
       Aggregate(table1.output, table1.output,
-        Join(table1, table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd))).analyze
+        Join(table1, table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd), 
JoinHint.NONE)).analyze
 
     comparePlans(optimized, correctAnswer)
   }
@@ -160,7 +160,7 @@ class ReplaceOperatorSuite extends PlanTest {
 
     val correctAnswer =
       Aggregate(table1.output, table1.output,
-        Join(table1, table2, LeftAnti, Option('a <=> 'c && 'b <=> 'd))).analyze
+        Join(table1, table2, LeftAnti, Option('a <=> 'c && 'b <=> 'd), 
JoinHint.NONE)).analyze
 
     comparePlans(optimized, correctAnswer)
   }
@@ -175,7 +175,7 @@ class ReplaceOperatorSuite extends PlanTest {
 
     val correctAnswer =
       Aggregate(left.output, right.output,
-        Join(left, right, LeftAnti, Option($"left.a" <=> $"right.a"))).analyze
+        Join(left, right, LeftAnti, Option($"left.a" <=> $"right.a"), 
JoinHint.NONE)).analyze
 
     comparePlans(optimized, correctAnswer)
   }
@@ -248,7 +248,7 @@ class ReplaceOperatorSuite extends PlanTest {
     val condition = basePlan.output.zip(otherPlan.output).map { case (a1, a2) 
=>
       a1 <=> a2 }.reduce( _ && _)
     val correctAnswer = Aggregate(basePlan.output, otherPlan.output,
-      Join(basePlan, otherPlan, LeftAnti, Option(condition))).analyze
+      Join(basePlan, otherPlan, LeftAnti, Option(condition), 
JoinHint.NONE)).analyze
     comparePlans(result, correctAnswer)
   }
 }
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 3081ff935f043..5394732f41f2d 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -99,11 +99,11 @@ trait PlanTestBase extends PredicateHelper with SQLHelper { 
self: Suite =>
           .reduce(And), child)
       case sample: Sample =>
         sample.copy(seed = 0L)
-      case Join(left, right, joinType, condition) if condition.isDefined =>
+      case Join(left, right, joinType, condition, hint) if condition.isDefined 
=>
         val newCondition =
           
splitConjunctivePredicates(condition.get).map(rewriteEqual).sortBy(_.hashCode())
             .reduce(And)
-        Join(left, right, joinType, Some(newCondition))
+        Join(left, right, joinType, Some(newCondition), hint)
     }
   }
 
@@ -165,8 +165,10 @@ trait PlanTestBase extends PredicateHelper with SQLHelper 
{ self: Suite =>
   private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = {
     (plan1, plan2) match {
       case (j1: Join, j2: Join) =>
-        (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) ||
-          (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left))
+        (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)
+          && j1.hint.leftHint == j2.hint.leftHint && j1.hint.rightHint == 
j2.hint.rightHint) ||
+          (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left)
+            && j1.hint.leftHint == j2.hint.rightHint && j1.hint.rightHint == 
j2.hint.leftHint)
       case (p1: Project, p2: Project) =>
         p1.projectList == p2.projectList && sameJoinPlan(p1.child, p2.child)
       case _ =>
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala
index 7c8ed78a49116..fbaaf807af5d6 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala
@@ -20,7 +20,9 @@ package org.apache.spark.sql.catalyst.plans
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, 
LogicalPlan, ResolvedHint, Union}
+import org.apache.spark.sql.catalyst.optimizer.EliminateResolvedHint
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.catalyst.util._
 
 /**
@@ -30,6 +32,10 @@ class SameResultSuite extends SparkFunSuite {
   val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
   val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int)
 
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches = Batch("EliminateResolvedHint", Once, EliminateResolvedHint) 
:: Nil
+  }
+
   def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = 
true): Unit = {
     val aAnalyzed = a.analyze
     val bAnalyzed = b.analyze
@@ -72,4 +78,12 @@ class SameResultSuite extends SparkFunSuite {
     val df2 = testRelation.join(testRelation)
     assertSameResult(df1, df2)
   }
+
+  test("join hint") {
+    val df1 = testRelation.join(testRelation.hint("broadcast"))
+    val df2 = testRelation.join(testRelation)
+    val df1Optimized = Optimize.execute(df1.analyze)
+    val df2Optimized = Optimize.execute(df2.analyze)
+    assertSameResult(df1Optimized, df2Optimized)
+  }
 }
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
index 953094cb0dd52..16a5c2d3001a7 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
@@ -38,24 +38,6 @@ class BasicStatsEstimationSuite extends PlanTest with 
StatsEstimationTestBase {
     // row count * (overhead + column size)
     size = Some(10 * (8 + 4)))
 
-  test("BroadcastHint estimation") {
-    val filter = Filter(Literal(true), plan)
-    val filterStatsCboOn = Statistics(sizeInBytes = 10 * (8 +4),
-      rowCount = Some(10), attributeStats = AttributeMap(Seq(attribute -> 
colStat)))
-    val filterStatsCboOff = Statistics(sizeInBytes = 10 * (8 +4))
-    checkStats(
-      filter,
-      expectedStatsCboOn = filterStatsCboOn,
-      expectedStatsCboOff = filterStatsCboOff)
-
-    val broadcastHint = ResolvedHint(filter, HintInfo(broadcast = true))
-    checkStats(
-      broadcastHint,
-      expectedStatsCboOn = filterStatsCboOn.copy(hints = HintInfo(broadcast = 
true)),
-      expectedStatsCboOff = filterStatsCboOff.copy(hints = HintInfo(broadcast 
= true))
-    )
-  }
-
   test("range") {
     val range = Range(1, 5, 1, None)
     val rangeStats = Statistics(sizeInBytes = 4 * 8)
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
index b0a47e7835129..1cf888519077a 100755
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
@@ -528,7 +528,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase 
{
       rowCount = 30,
       attributeStats = AttributeMap(Seq(attrIntLargerRange -> 
colStatIntLargerRange)))
     val nonLeafChild = Join(largerTable, smallerTable, LeftOuter,
-      Some(EqualTo(attrIntLargerRange, attrInt)))
+      Some(EqualTo(attrIntLargerRange, attrInt)), JoinHint.NONE)
 
     Seq(IsNull(attrIntLargerRange), IsNotNull(attrIntLargerRange)).foreach { 
predicate =>
       validateEstimatedStats(
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
index 12c0a7be21292..6c5a2b247fc23 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
@@ -79,8 +79,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
     val c1 = generateJoinChild(col1, leftHistogram, expectedMin, expectedMax)
     val c2 = generateJoinChild(col2, rightHistogram, expectedMin, expectedMax)
 
-    val c1JoinC2 = Join(c1, c2, Inner, Some(EqualTo(col1, col2)))
-    val c2JoinC1 = Join(c2, c1, Inner, Some(EqualTo(col2, col1)))
+    val c1JoinC2 = Join(c1, c2, Inner, Some(EqualTo(col1, col2)), 
JoinHint.NONE)
+    val c2JoinC1 = Join(c2, c1, Inner, Some(EqualTo(col2, col1)), 
JoinHint.NONE)
     val expectedStatsAfterJoin = Statistics(
       sizeInBytes = expectedRows * (8 + 2 * 4),
       rowCount = Some(expectedRows),
@@ -284,7 +284,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
   test("cross join") {
     // table1 (key-1-5 int, key-5-9 int): (1, 9), (2, 8), (3, 7), (4, 6), (5, 
5)
     // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
-    val join = Join(table1, table2, Cross, None)
+    val join = Join(table1, table2, Cross, None, JoinHint.NONE)
     val expectedStats = Statistics(
       sizeInBytes = 5 * 3 * (8 + 4 * 4),
       rowCount = Some(5 * 3),
@@ -299,7 +299,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
     // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
     // key-5-9 and key-2-4 are disjoint
     val join = Join(table1, table2, Inner,
-      Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))))
+      Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))), 
JoinHint.NONE)
     val expectedStats = Statistics(
       sizeInBytes = 1,
       rowCount = Some(0),
@@ -312,7 +312,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
     // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
     // key-5-9 and key-2-4 are disjoint
     val join = Join(table1, table2, LeftOuter,
-      Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))))
+      Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))), 
JoinHint.NONE)
     val expectedStats = Statistics(
       sizeInBytes = 5 * (8 + 4 * 4),
       rowCount = Some(5),
@@ -328,7 +328,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
     // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
     // key-5-9 and key-2-4 are disjoint
     val join = Join(table1, table2, RightOuter,
-      Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))))
+      Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))), 
JoinHint.NONE)
     val expectedStats = Statistics(
       sizeInBytes = 3 * (8 + 4 * 4),
       rowCount = Some(3),
@@ -344,7 +344,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
     // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
     // key-5-9 and key-2-4 are disjoint
     val join = Join(table1, table2, FullOuter,
-      Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))))
+      Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))), 
JoinHint.NONE)
     val expectedStats = Statistics(
       sizeInBytes = (5 + 3) * (8 + 4 * 4),
       rowCount = Some(5 + 3),
@@ -361,7 +361,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
     // table1 (key-1-5 int, key-5-9 int): (1, 9), (2, 8), (3, 7), (4, 6), (5, 
5)
     // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
     val join = Join(table1, table2, Inner,
-      Some(EqualTo(nameToAttr("key-1-5"), nameToAttr("key-1-2"))))
+      Some(EqualTo(nameToAttr("key-1-5"), nameToAttr("key-1-2"))), 
JoinHint.NONE)
     // Update column stats for equi-join keys (key-1-5 and key-1-2).
     val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(1), max 
= Some(2),
       nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))
@@ -383,7 +383,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
     // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3)
     val join = Join(table2, table3, Inner, Some(
       And(EqualTo(nameToAttr("key-1-2"), nameToAttr("key-1-2")),
-        EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3")))))
+        EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3")))), JoinHint.NONE)
 
     // Update column stats for join keys.
     val joinedColStat1 = ColumnStat(distinctCount = Some(2), min = Some(1), 
max = Some(2),
@@ -404,7 +404,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
     // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
     // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3)
     val join = Join(table3, table2, LeftOuter,
-      Some(EqualTo(nameToAttr("key-2-3"), nameToAttr("key-2-4"))))
+      Some(EqualTo(nameToAttr("key-2-3"), nameToAttr("key-2-4"))), 
JoinHint.NONE)
     val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(2), max 
= Some(3),
       nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))
 
@@ -422,7 +422,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
     // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
     // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3)
     val join = Join(table2, table3, RightOuter,
-      Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))))
+      Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))), 
JoinHint.NONE)
     val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(2), max 
= Some(3),
       nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))
 
@@ -440,7 +440,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
     // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
     // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3)
     val join = Join(table2, table3, FullOuter,
-      Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))))
+      Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))), 
JoinHint.NONE)
 
     val expectedStats = Statistics(
       sizeInBytes = 3 * (8 + 4 * 4),
@@ -456,7 +456,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
     // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3)
     Seq(LeftSemi, LeftAnti).foreach { jt =>
       val join = Join(table2, table3, jt,
-        Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))))
+        Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))), 
JoinHint.NONE)
       // For now we just propagate the statistics from left side for left 
semi/anti join.
       val expectedStats = Statistics(
         sizeInBytes = 3 * (8 + 4 * 2),
@@ -525,7 +525,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
       withClue(s"For data type ${key1.dataType}") {
         // All values in two tables are the same, so column stats after join 
are also the same.
         val join = Join(Project(Seq(key1), table1), Project(Seq(key2), 
table2), Inner,
-          Some(EqualTo(key1, key2)))
+          Some(EqualTo(key1, key2)), JoinHint.NONE)
         val expectedStats = Statistics(
           sizeInBytes = 1 * (8 + 2 * getColSize(key1, columnInfo1(key1))),
           rowCount = Some(1),
@@ -543,7 +543,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
       outputList = Seq(nullColumn),
       rowCount = 1,
       attributeStats = AttributeMap(Seq(nullColumn -> nullColStat)))
-    val join = Join(table1, nullTable, Inner, 
Some(EqualTo(nameToAttr("key-1-5"), nullColumn)))
+    val join = Join(table1, nullTable, Inner,
+      Some(EqualTo(nameToAttr("key-1-5"), nullColumn)), JoinHint.NONE)
     val expectedStats = Statistics(
       sizeInBytes = 1,
       rowCount = Some(0),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index a664c7338badb..44cada086489a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -862,7 +862,7 @@ class Dataset[T] private[sql](
    * @since 2.0.0
    */
   def join(right: Dataset[_]): DataFrame = withPlan {
-    Join(logicalPlan, right.logicalPlan, joinType = Inner, None)
+    Join(logicalPlan, right.logicalPlan, joinType = Inner, None, JoinHint.NONE)
   }
 
   /**
@@ -940,7 +940,7 @@ class Dataset[T] private[sql](
     // Analyze the self join. The assumption is that the analyzer will 
disambiguate left vs right
     // by creating a new instance for one of the branch.
     val joined = sparkSession.sessionState.executePlan(
-      Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), 
None))
+      Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), 
None, JoinHint.NONE))
       .analyzed.asInstanceOf[Join]
 
     withPlan {
@@ -948,7 +948,8 @@ class Dataset[T] private[sql](
         joined.left,
         joined.right,
         UsingJoin(JoinType(joinType), usingColumns),
-        None)
+        None,
+        JoinHint.NONE)
     }
   }
 
@@ -1001,7 +1002,7 @@ class Dataset[T] private[sql](
     // Trigger analysis so in the case of self-join, the analyzer will clone 
the plan.
     // After the cloning, left and right side will have distinct expression 
ids.
     val plan = withPlan(
-      Join(logicalPlan, right.logicalPlan, JoinType(joinType), 
Some(joinExprs.expr)))
+      Join(logicalPlan, right.logicalPlan, JoinType(joinType), 
Some(joinExprs.expr), JoinHint.NONE))
       .queryExecution.analyzed.asInstanceOf[Join]
 
     // If auto self join alias is disabled, return the plan.
@@ -1048,7 +1049,7 @@ class Dataset[T] private[sql](
    * @since 2.1.0
    */
   def crossJoin(right: Dataset[_]): DataFrame = withPlan {
-    Join(logicalPlan, right.logicalPlan, joinType = Cross, None)
+    Join(logicalPlan, right.logicalPlan, joinType = Cross, None, JoinHint.NONE)
   }
 
   /**
@@ -1083,7 +1084,8 @@ class Dataset[T] private[sql](
         this.logicalPlan,
         other.logicalPlan,
         JoinType(joinType),
-        Some(condition.expr))).analyzed.asInstanceOf[Join]
+        Some(condition.expr),
+        JoinHint.NONE)).analyzed.asInstanceOf[Join]
 
     if (joined.joinType == LeftSemi || joined.joinType == LeftAnti) {
       throw new AnalysisException("Invalid join type in joinWith: " + 
joined.joinType.sql)
@@ -1135,7 +1137,7 @@ class Dataset[T] private[sql](
     implicit val tuple2Encoder: Encoder[(T, U)] =
       ExpressionEncoder.tuple(this.exprEnc, other.exprEnc)
 
-    withTypedPlan(Join(left, right, joined.joinType, Some(conditionExpr)))
+    withTypedPlan(Join(left, right, joined.joinType, Some(conditionExpr), 
JoinHint.NONE))
   }
 
   /**
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index dbc6db62bd820..b7cc373b2df12 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -208,17 +208,17 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
       }
     }
 
-    private def canBroadcastByHints(joinType: JoinType, left: LogicalPlan, 
right: LogicalPlan)
-      : Boolean = {
-      val buildLeft = canBuildLeft(joinType) && left.stats.hints.broadcast
-      val buildRight = canBuildRight(joinType) && right.stats.hints.broadcast
+    private def canBroadcastByHints(
+        joinType: JoinType, left: LogicalPlan, right: LogicalPlan, hint: 
JoinHint): Boolean = {
+      val buildLeft = canBuildLeft(joinType) && 
hint.leftHint.exists(_.broadcast)
+      val buildRight = canBuildRight(joinType) && 
hint.rightHint.exists(_.broadcast)
       buildLeft || buildRight
     }
 
-    private def broadcastSideByHints(joinType: JoinType, left: LogicalPlan, 
right: LogicalPlan)
-      : BuildSide = {
-      val buildLeft = canBuildLeft(joinType) && left.stats.hints.broadcast
-      val buildRight = canBuildRight(joinType) && right.stats.hints.broadcast
+    private def broadcastSideByHints(
+        joinType: JoinType, left: LogicalPlan, right: LogicalPlan, hint: 
JoinHint): BuildSide = {
+      val buildLeft = canBuildLeft(joinType) && 
hint.leftHint.exists(_.broadcast)
+      val buildRight = canBuildRight(joinType) && 
hint.rightHint.exists(_.broadcast)
       broadcastSide(buildLeft, buildRight, left, right)
     }
 
@@ -241,14 +241,14 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
       // --- BroadcastHashJoin 
--------------------------------------------------------------------
 
       // broadcast hints were specified
-      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, 
right)
-        if canBroadcastByHints(joinType, left, right) =>
-        val buildSide = broadcastSideByHints(joinType, left, right)
+      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, 
right, hint)
+        if canBroadcastByHints(joinType, left, right, hint) =>
+        val buildSide = broadcastSideByHints(joinType, left, right, hint)
         Seq(joins.BroadcastHashJoinExec(
           leftKeys, rightKeys, joinType, buildSide, condition, 
planLater(left), planLater(right)))
 
       // broadcast hints were not specified, so need to infer it from size and 
configuration.
-      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, 
right)
+      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, 
right, _)
         if canBroadcastBySizes(joinType, left, right) =>
         val buildSide = broadcastSideBySizes(joinType, left, right)
         Seq(joins.BroadcastHashJoinExec(
@@ -256,14 +256,14 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
 
       // --- ShuffledHashJoin 
---------------------------------------------------------------------
 
-      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, 
right)
+      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, 
right, _)
          if !conf.preferSortMergeJoin && canBuildRight(joinType) && 
canBuildLocalHashMap(right)
            && muchSmaller(right, left) ||
            !RowOrdering.isOrderable(leftKeys) =>
         Seq(joins.ShuffledHashJoinExec(
           leftKeys, rightKeys, joinType, BuildRight, condition, 
planLater(left), planLater(right)))
 
-      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, 
right)
+      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, 
right, _)
          if !conf.preferSortMergeJoin && canBuildLeft(joinType) && 
canBuildLocalHashMap(left)
            && muchSmaller(left, right) ||
            !RowOrdering.isOrderable(leftKeys) =>
@@ -272,7 +272,7 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
 
       // --- SortMergeJoin 
------------------------------------------------------------
 
-      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, 
right)
+      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, 
right, _)
         if RowOrdering.isOrderable(leftKeys) =>
         joins.SortMergeJoinExec(
           leftKeys, rightKeys, joinType, condition, planLater(left), 
planLater(right)) :: Nil
@@ -280,25 +280,25 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
       // --- Without joining keys 
------------------------------------------------------------
 
       // Pick BroadcastNestedLoopJoin if one side could be broadcast
-      case j @ logical.Join(left, right, joinType, condition)
-          if canBroadcastByHints(joinType, left, right) =>
-        val buildSide = broadcastSideByHints(joinType, left, right)
+      case j @ logical.Join(left, right, joinType, condition, hint)
+          if canBroadcastByHints(joinType, left, right, hint) =>
+        val buildSide = broadcastSideByHints(joinType, left, right, hint)
         joins.BroadcastNestedLoopJoinExec(
           planLater(left), planLater(right), buildSide, joinType, condition) 
:: Nil
 
-      case j @ logical.Join(left, right, joinType, condition)
+      case j @ logical.Join(left, right, joinType, condition, _)
           if canBroadcastBySizes(joinType, left, right) =>
         val buildSide = broadcastSideBySizes(joinType, left, right)
         joins.BroadcastNestedLoopJoinExec(
           planLater(left), planLater(right), buildSide, joinType, condition) 
:: Nil
 
       // Pick CartesianProduct for InnerJoin
-      case logical.Join(left, right, _: InnerLike, condition) =>
+      case logical.Join(left, right, _: InnerLike, condition, _) =>
         joins.CartesianProductExec(planLater(left), planLater(right), 
condition) :: Nil
 
-      case logical.Join(left, right, joinType, condition) =>
+      case logical.Join(left, right, joinType, condition, hint) =>
         val buildSide = broadcastSide(
-          left.stats.hints.broadcast, right.stats.hints.broadcast, left, right)
+          hint.leftHint.exists(_.broadcast), 
hint.rightHint.exists(_.broadcast), left, right)
         // This join could be very slow or OOM
         joins.BroadcastNestedLoopJoinExec(
           planLater(left), planLater(right), buildSide, joinType, condition) 
:: Nil
@@ -380,13 +380,13 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
   object StreamingJoinStrategy extends Strategy {
     override def apply(plan: LogicalPlan): Seq[SparkPlan] = {
       plan match {
-        case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, 
left, right)
+        case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, 
left, right, _)
           if left.isStreaming && right.isStreaming =>
 
           new StreamingSymmetricHashJoinExec(
             leftKeys, rightKeys, joinType, condition, planLater(left), 
planLater(right)) :: Nil
 
-        case Join(left, right, _, _) if left.isStreaming && right.isStreaming 
=>
+        case Join(left, right, _, _, _) if left.isStreaming && 
right.isStreaming =>
           throw new AnalysisException(
             "Stream-stream join without equality predicate is not supported", 
plan = Some(plan))
 
@@ -561,6 +561,9 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
         throw new IllegalStateException(
           "logical except (all) operator should have been replaced by union, 
aggregate" +
             " and generate operators in the optimizer")
+      case logical.ResolvedHint(child, hints) =>
+        throw new IllegalStateException(
+          "ResolvedHint operator should have been replaced by join hint in the 
optimizer")
 
       case logical.DeserializeToObject(deserializer, objAttr, child) =>
         execution.DeserializeToObjectExec(deserializer, objAttr, 
planLater(child)) :: Nil
@@ -632,7 +635,6 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
       case ExternalRDD(outputObjAttr, rdd) => 
ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
       case r: LogicalRDD =>
         RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, 
r.outputOrdering) :: Nil
-      case h: ResolvedHint => planLater(h.child) :: Nil
       case _ => Nil
     }
   }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 4109d9994dd8f..41f406d6c2993 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -26,7 +26,7 @@ import 
org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, LogicalPlan, 
Statistics}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.storage.StorageLevel
@@ -184,12 +184,7 @@ case class InMemoryRelation(
   override def computeStats(): Statistics = {
     if (cacheBuilder.sizeInBytesStats.value == 0L) {
       // Underlying columnar RDD hasn't been materialized, use the stats from 
the plan to cache.
-      // Note that we should drop the hint info here. We may cache a plan 
whose root node is a hint
-      // node. When we lookup the cache with a semantically same plan without 
hint info, the plan
-      // returned by cache lookup should not have hint info. If we lookup the 
cache with a
-      // semantically same plan with a different hint info, 
`CacheManager.useCachedData` will take
-      // care of it and retain the hint info in the lookup input plan.
-      statsOfPlanToCache.copy(hints = HintInfo())
+      statsOfPlanToCache
     } else {
       Statistics(sizeInBytes = cacheBuilder.sizeInBytesStats.value.longValue)
     }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 6e805c4f3c39a..2141be4d680f9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -27,9 +27,11 @@ import 
org.apache.spark.executor.DataReadMethod.DataReadMethod
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
+import org.apache.spark.sql.catalyst.plans.logical.Join
 import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan}
 import org.apache.spark.sql.execution.columnar._
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
 import org.apache.spark.storage.{RDDBlockId, StorageLevel}
@@ -925,4 +927,23 @@ class CachedTableSuite extends QueryTest with SQLTestUtils 
with SharedSQLContext
       }
     }
   }
+
+  test("Cache should respect the broadcast hint") {
+    val df = broadcast(spark.range(1000)).cache()
+    val df2 = spark.range(1000).cache()
+    df.count()
+    df2.count()
+
+    // Test the broadcast hint.
+    val joinPlan = df.join(df2, "id").queryExecution.optimizedPlan
+    val hint = joinPlan.collect {
+      case Join(_, _, _, _, hint) => hint
+    }
+    assert(hint.size == 1)
+    assert(hint(0).leftHint.get.broadcast)
+    assert(hint(0).rightHint.isEmpty)
+
+    // Clean-up
+    df.unpersist()
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index c9f41ab1c0179..a4a3e2a62d1a5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -198,7 +198,7 @@ class DataFrameJoinSuite extends QueryTest with 
SharedSQLContext {
     // outer -> left
     val outerJoin2Left = df.join(df2, $"a.int" === $"b.int", 
"outer").where($"a.int" >= 3)
     assert(outerJoin2Left.queryExecution.optimizedPlan.collect {
-      case j @ Join(_, _, LeftOuter, _) => j }.size === 1)
+      case j @ Join(_, _, LeftOuter, _, _) => j }.size === 1)
     checkAnswer(
       outerJoin2Left,
       Row(3, 4, "3", null, null, null) :: Nil)
@@ -206,7 +206,7 @@ class DataFrameJoinSuite extends QueryTest with 
SharedSQLContext {
     // outer -> right
     val outerJoin2Right = df.join(df2, $"a.int" === $"b.int", 
"outer").where($"b.int" >= 3)
     assert(outerJoin2Right.queryExecution.optimizedPlan.collect {
-      case j @ Join(_, _, RightOuter, _) => j }.size === 1)
+      case j @ Join(_, _, RightOuter, _, _) => j }.size === 1)
     checkAnswer(
       outerJoin2Right,
       Row(null, null, null, 5, 6, "5") :: Nil)
@@ -215,7 +215,7 @@ class DataFrameJoinSuite extends QueryTest with 
SharedSQLContext {
     val outerJoin2Inner = df.join(df2, $"a.int" === $"b.int", "outer").
       where($"a.int" === 1 && $"b.int2" === 3)
     assert(outerJoin2Inner.queryExecution.optimizedPlan.collect {
-      case j @ Join(_, _, Inner, _) => j }.size === 1)
+      case j @ Join(_, _, Inner, _, _) => j }.size === 1)
     checkAnswer(
       outerJoin2Inner,
       Row(1, 2, "1", 1, 3, "1") :: Nil)
@@ -223,7 +223,7 @@ class DataFrameJoinSuite extends QueryTest with 
SharedSQLContext {
     // right -> inner
     val rightJoin2Inner = df.join(df2, $"a.int" === $"b.int", 
"right").where($"a.int" > 0)
     assert(rightJoin2Inner.queryExecution.optimizedPlan.collect {
-      case j @ Join(_, _, Inner, _) => j }.size === 1)
+      case j @ Join(_, _, Inner, _, _) => j }.size === 1)
     checkAnswer(
       rightJoin2Inner,
       Row(1, 2, "1", 1, 3, "1") :: Nil)
@@ -231,7 +231,7 @@ class DataFrameJoinSuite extends QueryTest with 
SharedSQLContext {
     // left -> inner
     val leftJoin2Inner = df.join(df2, $"a.int" === $"b.int", 
"left").where($"b.int2" > 0)
     assert(leftJoin2Inner.queryExecution.optimizedPlan.collect {
-      case j @ Join(_, _, Inner, _) => j }.size === 1)
+      case j @ Join(_, _, Inner, _, _) => j }.size === 1)
     checkAnswer(
       leftJoin2Inner,
       Row(1, 2, "1", 1, 3, "1") :: Nil)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
new file mode 100644
index 0000000000000..3652895ff43d8
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.test.SharedSQLContext
+
+class JoinHintSuite extends PlanTest with SharedSQLContext {
+  import testImplicits._
+
+  lazy val df = spark.range(10)
+  lazy val df1 = df.selectExpr("id as a1", "id as a2")
+  lazy val df2 = df.selectExpr("id as b1", "id as b2")
+  lazy val df3 = df.selectExpr("id as c1", "id as c2")
+
+  def verifyJoinHint(df: DataFrame, expectedHints: Seq[JoinHint]): Unit = {
+    val optimized = df.queryExecution.optimizedPlan
+    val joinHints = optimized collect {
+      case Join(_, _, _, _, hint) => hint
+      case _: ResolvedHint => fail("ResolvedHint should not appear after 
optimize.")
+    }
+    assert(joinHints == expectedHints)
+  }
+
+  test("single join") {
+    verifyJoinHint(
+      df.hint("broadcast").join(df, "id"),
+      JoinHint(
+        Some(HintInfo(broadcast = true)),
+        None) :: Nil
+    )
+    verifyJoinHint(
+      df.join(df.hint("broadcast"), "id"),
+      JoinHint(
+        None,
+        Some(HintInfo(broadcast = true))) :: Nil
+    )
+  }
+
+  test("multiple joins") {
+    verifyJoinHint(
+      df1.join(df2.hint("broadcast").join(df3, 'b1 === 'c1).hint("broadcast"), 
'a1 === 'c1),
+      JoinHint(
+        None,
+        Some(HintInfo(broadcast = true))) ::
+        JoinHint(
+          Some(HintInfo(broadcast = true)),
+          None) :: Nil
+    )
+    verifyJoinHint(
+      df1.hint("broadcast").join(df2, 'a1 === 'b1).hint("broadcast").join(df3, 
'a1 === 'c1),
+      JoinHint(
+        Some(HintInfo(broadcast = true)),
+        None) ::
+        JoinHint(
+          Some(HintInfo(broadcast = true)),
+          None) :: Nil
+    )
+  }
+
+  test("hint scope") {
+    withTempView("a", "b", "c") {
+      df1.createOrReplaceTempView("a")
+      df2.createOrReplaceTempView("b")
+      verifyJoinHint(
+        sql(
+          """
+            |select /*+ broadcast(a, b)*/ * from (
+            |  select /*+ broadcast(b)*/ * from a join b on a.a1 = b.b1
+            |) a join (
+            |  select /*+ broadcast(a)*/ * from a join b on a.a1 = b.b1
+            |) b on a.a1 = b.b1
+          """.stripMargin),
+        JoinHint(
+          Some(HintInfo(broadcast = true)),
+          Some(HintInfo(broadcast = true))) ::
+          JoinHint(
+            None,
+            Some(HintInfo(broadcast = true))) ::
+          JoinHint(
+            Some(HintInfo(broadcast = true)),
+            None) :: Nil
+      )
+    }
+  }
+
+  test("hint preserved after join reorder") {
+    withTempView("a", "b", "c") {
+      df1.createOrReplaceTempView("a")
+      df2.createOrReplaceTempView("b")
+      df3.createOrReplaceTempView("c")
+      verifyJoinHint(
+        sql("select /*+ broadcast(a, c)*/ * from a, b, c " +
+          "where a.a1 = b.b1 and b.b1 = c.c1"),
+        JoinHint(
+          None,
+          Some(HintInfo(broadcast = true))) ::
+          JoinHint(
+            Some(HintInfo(broadcast = true)),
+            None):: Nil
+      )
+      verifyJoinHint(
+        sql("select /*+ broadcast(a, c)*/ * from a, c, b " +
+          "where a.a1 = b.b1 and b.b1 = c.c1"),
+        JoinHint(
+          None,
+          Some(HintInfo(broadcast = true))) ::
+          JoinHint(
+            Some(HintInfo(broadcast = true)),
+            None):: Nil
+      )
+      verifyJoinHint(
+        sql("select /*+ broadcast(b, c)*/ * from a, c, b " +
+          "where a.a1 = b.b1 and b.b1 = c.c1"),
+        JoinHint(
+          None,
+          Some(HintInfo(broadcast = true))) ::
+          JoinHint(
+            None,
+            Some(HintInfo(broadcast = true))):: Nil
+      )
+
+      verifyJoinHint(
+        df1.join(df2, 'a1 === 'b1 && 'a1 > 5).hint("broadcast")
+          .join(df3, 'b1 === 'c1 && 'a1 < 10),
+        JoinHint(
+          Some(HintInfo(broadcast = true)),
+          None) ::
+          JoinHint.NONE:: Nil
+      )
+
+      verifyJoinHint(
+        df1.join(df2, 'a1 === 'b1 && 'a1 > 5).hint("broadcast")
+          .join(df3, 'b1 === 'c1 && 'a1 < 10)
+          .join(df, 'b1 === 'id),
+        JoinHint.NONE ::
+          JoinHint(
+            Some(HintInfo(broadcast = true)),
+            None) ::
+          JoinHint.NONE:: Nil
+      )
+    }
+  }
+
+  test("intersect/except") {
+    val dfSub = spark.range(2)
+    verifyJoinHint(
+      df.hint("broadcast").except(dfSub).join(df, "id"),
+      JoinHint(
+        Some(HintInfo(broadcast = true)),
+        None) ::
+        JoinHint.NONE :: Nil
+    )
+    verifyJoinHint(
+      df.join(df.hint("broadcast").intersect(dfSub), "id"),
+      JoinHint(
+        None,
+        Some(HintInfo(broadcast = true))) ::
+        JoinHint.NONE :: Nil
+    )
+  }
+
+  test("hint merge") {
+    verifyJoinHint(
+      df.hint("broadcast").filter('id > 2).hint("broadcast").join(df, "id"),
+      JoinHint(
+        Some(HintInfo(broadcast = true)),
+        None) :: Nil
+    )
+    verifyJoinHint(
+      df.join(df.hint("broadcast").limit(2).hint("broadcast"), "id"),
+      JoinHint(
+        None,
+        Some(HintInfo(broadcast = true))) :: Nil
+    )
+  }
+}
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
index 02dc32d5f90ba..99842680cedfe 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
@@ -237,8 +237,7 @@ class StatisticsCollectionSuite extends 
StatisticsCollectionTestBase with Shared
     )
     numbers.foreach { case (input, (expectedSize, expectedRows)) =>
       val stats = Statistics(sizeInBytes = input, rowCount = Some(input))
-      val expectedString = s"sizeInBytes=$expectedSize, 
rowCount=$expectedRows," +
-        s" hints=none"
+      val expectedString = s"sizeInBytes=$expectedSize, rowCount=$expectedRows"
       assert(stats.simpleString == expectedString)
     }
   }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 42dd0024b2582..f238148e61c39 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -203,7 +203,7 @@ class BroadcastJoinSuite extends QueryTest with 
SQLTestUtils {
   }
 
   test("broadcast hint in SQL") {
-    import org.apache.spark.sql.catalyst.plans.logical.{ResolvedHint, Join}
+    import org.apache.spark.sql.catalyst.plans.logical.Join
 
     spark.range(10).createOrReplaceTempView("t")
     spark.range(10).createOrReplaceTempView("u")
@@ -216,12 +216,12 @@ class BroadcastJoinSuite extends QueryTest with 
SQLTestUtils {
       val plan3 = sql(s"SELECT /*+ $name(v) */ * FROM t JOIN u ON t.id = 
u.id").queryExecution
         .optimizedPlan
 
-      assert(plan1.asInstanceOf[Join].left.isInstanceOf[ResolvedHint])
-      assert(!plan1.asInstanceOf[Join].right.isInstanceOf[ResolvedHint])
-      assert(!plan2.asInstanceOf[Join].left.isInstanceOf[ResolvedHint])
-      assert(plan2.asInstanceOf[Join].right.isInstanceOf[ResolvedHint])
-      assert(!plan3.asInstanceOf[Join].left.isInstanceOf[ResolvedHint])
-      assert(!plan3.asInstanceOf[Join].right.isInstanceOf[ResolvedHint])
+      assert(plan1.asInstanceOf[Join].hint.leftHint.get.broadcast)
+      assert(plan1.asInstanceOf[Join].hint.rightHint.isEmpty)
+      assert(plan2.asInstanceOf[Join].hint.leftHint.isEmpty)
+      assert(plan2.asInstanceOf[Join].hint.rightHint.get.broadcast)
+      assert(plan3.asInstanceOf[Join].hint.leftHint.isEmpty)
+      assert(plan3.asInstanceOf[Join].hint.rightHint.isEmpty)
     }
   }
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
index 22279a3a43eff..771a9730247af 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.Join
+import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint}
 import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan, 
SparkPlanTest}
 import org.apache.spark.sql.execution.exchange.EnsureRequirements
 import org.apache.spark.sql.internal.SQLConf
@@ -85,7 +85,8 @@ class ExistenceJoinSuite extends SparkPlanTest with 
SharedSQLContext {
       expectedAnswer: Seq[Row]): Unit = {
 
     def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
-      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, 
Some(condition))
+      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan,
+        Inner, Some(condition), JoinHint.NONE)
       ExtractEquiJoinKeys.unapply(join)
     }
 
@@ -102,7 +103,7 @@ class ExistenceJoinSuite extends SparkPlanTest with 
SharedSQLContext {
     }
 
     test(s"$testName using ShuffledHashJoin") {
-      extractJoinParts().foreach { case (_, leftKeys, rightKeys, 
boundCondition, _, _) =>
+      extractJoinParts().foreach { case (_, leftKeys, rightKeys, 
boundCondition, _, _, _) =>
         withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
           checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: 
SparkPlan) =>
             EnsureRequirements(left.sqlContext.sessionState.conf).apply(
@@ -121,7 +122,7 @@ class ExistenceJoinSuite extends SparkPlanTest with 
SharedSQLContext {
     }
 
     testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin") { 
_ =>
-      extractJoinParts().foreach { case (_, leftKeys, rightKeys, 
boundCondition, _, _) =>
+      extractJoinParts().foreach { case (_, leftKeys, rightKeys, 
boundCondition, _, _, _) =>
         withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
           checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: 
SparkPlan) =>
             EnsureRequirements(left.sqlContext.sessionState.conf).apply(
@@ -140,7 +141,7 @@ class ExistenceJoinSuite extends SparkPlanTest with 
SharedSQLContext {
     }
 
     test(s"$testName using SortMergeJoin") {
-      extractJoinParts().foreach { case (_, leftKeys, rightKeys, 
boundCondition, _, _) =>
+      extractJoinParts().foreach { case (_, leftKeys, rightKeys, 
boundCondition, _, _, _) =>
         withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
           checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: 
SparkPlan) =>
             EnsureRequirements(left.sqlContext.sessionState.conf).apply(
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index f5edd6bbd5e69..f99a278bb2427 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans.Inner
-import org.apache.spark.sql.catalyst.plans.logical.Join
+import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.exchange.EnsureRequirements
 import org.apache.spark.sql.internal.SQLConf
@@ -80,7 +80,8 @@ class InnerJoinSuite extends SparkPlanTest with 
SharedSQLContext {
       expectedAnswer: Seq[Product]): Unit = {
 
     def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
-      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, 
Some(condition()))
+      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan,
+        Inner, Some(condition()), JoinHint.NONE)
       ExtractEquiJoinKeys.unapply(join)
     }
 
@@ -128,7 +129,7 @@ class InnerJoinSuite extends SparkPlanTest with 
SharedSQLContext {
     }
 
     testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin 
(build=left)") { _ =>
-      extractJoinParts().foreach { case (_, leftKeys, rightKeys, 
boundCondition, _, _) =>
+      extractJoinParts().foreach { case (_, leftKeys, rightKeys, 
boundCondition, _, _, _) =>
         withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
           checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: 
SparkPlan) =>
             makeBroadcastHashJoin(
@@ -140,7 +141,7 @@ class InnerJoinSuite extends SparkPlanTest with 
SharedSQLContext {
     }
 
     testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin 
(build=right)") { _ =>
-      extractJoinParts().foreach { case (_, leftKeys, rightKeys, 
boundCondition, _, _) =>
+      extractJoinParts().foreach { case (_, leftKeys, rightKeys, 
boundCondition, _, _, _) =>
         withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
           checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: 
SparkPlan) =>
             makeBroadcastHashJoin(
@@ -152,7 +153,7 @@ class InnerJoinSuite extends SparkPlanTest with 
SharedSQLContext {
     }
 
     test(s"$testName using ShuffledHashJoin (build=left)") {
-      extractJoinParts().foreach { case (_, leftKeys, rightKeys, 
boundCondition, _, _) =>
+      extractJoinParts().foreach { case (_, leftKeys, rightKeys, 
boundCondition, _, _, _) =>
         withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
           checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: 
SparkPlan) =>
             makeShuffledHashJoin(
@@ -164,7 +165,7 @@ class InnerJoinSuite extends SparkPlanTest with 
SharedSQLContext {
     }
 
     test(s"$testName using ShuffledHashJoin (build=right)") {
-      extractJoinParts().foreach { case (_, leftKeys, rightKeys, 
boundCondition, _, _) =>
+      extractJoinParts().foreach { case (_, leftKeys, rightKeys, 
boundCondition, _, _, _) =>
         withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
           checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: 
SparkPlan) =>
             makeShuffledHashJoin(
@@ -176,7 +177,7 @@ class InnerJoinSuite extends SparkPlanTest with 
SharedSQLContext {
     }
 
     testWithWholeStageCodegenOnAndOff(s"$testName using SortMergeJoin") { _ =>
-      extractJoinParts().foreach { case (_, leftKeys, rightKeys, 
boundCondition, _, _) =>
+      extractJoinParts().foreach { case (_, leftKeys, rightKeys, 
boundCondition, _, _, _) =>
         withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
           checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: 
SparkPlan) =>
             makeSortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, 
rightPlan),
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 513248dae48be..1f04fcf6ca451 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan}
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.Join
+import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint}
 import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
 import org.apache.spark.sql.execution.exchange.EnsureRequirements
 import org.apache.spark.sql.internal.SQLConf
@@ -72,13 +72,14 @@ class OuterJoinSuite extends SparkPlanTest with 
SharedSQLContext {
       expectedAnswer: Seq[Product]): Unit = {
 
     def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
-      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, 
Some(condition))
+      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan,
+        Inner, Some(condition), JoinHint.NONE)
       ExtractEquiJoinKeys.unapply(join)
     }
 
     if (joinType != FullOuter) {
       test(s"$testName using ShuffledHashJoin") {
-        extractJoinParts().foreach { case (_, leftKeys, rightKeys, 
boundCondition, _, _) =>
+        extractJoinParts().foreach { case (_, leftKeys, rightKeys, 
boundCondition, _, _, _) =>
           withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
             val buildSide = if (joinType == LeftOuter) BuildRight else 
BuildLeft
             checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: 
SparkPlan) =>
@@ -99,7 +100,7 @@ class OuterJoinSuite extends SparkPlanTest with 
SharedSQLContext {
           case RightOuter => BuildLeft
           case _ => fail(s"Unsupported join type $joinType")
         }
-        extractJoinParts().foreach { case (_, leftKeys, rightKeys, 
boundCondition, _, _) =>
+        extractJoinParts().foreach { case (_, leftKeys, rightKeys, 
boundCondition, _, _, _) =>
           withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
             checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: 
SparkPlan) =>
               BroadcastHashJoinExec(
@@ -112,7 +113,7 @@ class OuterJoinSuite extends SparkPlanTest with 
SharedSQLContext {
     }
 
     test(s"$testName using SortMergeJoin") {
-      extractJoinParts().foreach { case (_, leftKeys, rightKeys, 
boundCondition, _, _) =>
+      extractJoinParts().foreach { case (_, leftKeys, rightKeys, 
boundCondition, _, _, _) =>
         withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
           checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: 
SparkPlan) =>
             EnsureRequirements(spark.sessionState.conf).apply(


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to