diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
index 743d3ce944fe2..6540e95b01e3f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
@@ -31,40 +31,6 @@ import org.apache.spark.sql.internal.SQLConf
  * Cost-based join reorder.
  * We may have several join reorder algorithms in the future. This class is the entry of these
  * algorithms, and chooses which one to use.
- *
- * Note that join strategy hints, e.g. the broadcast hint, do not interfere with the reordering.
- * Such hints will be applied on the equivalent counterparts (i.e., join between the same relations
- * regardless of the join order) of the original nodes after reordering.
- * For example, the plan before reordering is like:
- *
- *            Join
- *            /   \
- *          Hint1 t4
- *          /
- *        Join
- *        / \
- *      Join t3
- *      /   \
- *    Hint2 t2
- *    /
- *   t1
- *
- * The original join order as illustrated above is "((t1 JOIN t2) JOIN t3) JOIN t4", and after
- * reordering, the new join order is "((t1 JOIN t3) JOIN t2) JOIN t4", so the new plan will be like:
- *
- *            Join
- *            /   \
- *          Hint1 t4
- *          /
- *        Join
- *        / \
- *      Join t2
- *      / \
- *    t1  t3
- *
- * "Hint1" is applied on "(t1 JOIN t3) JOIN t2" as it is equivalent to the 
original hinted node,
- * "(t1 JOIN t2) JOIN t3"; while "Hint2" has disappeared from the new plan 
since there is no
- * equivalent node to "t1 JOIN t2".
  */
 object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
 
@@ -74,30 +40,24 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
     if (!conf.cboEnabled || !conf.joinReorderEnabled) {
       plan
     } else {
-      // Use a map to track the hints on the join items.
-      val hintMap = new mutable.HashMap[AttributeSet, HintInfo]
       val result = plan transformDown {
         // Start reordering with a joinable item, which is an InnerLike join with conditions.
-        case j @ Join(_, _, _: InnerLike, Some(cond), _) =>
-          reorder(j, j.output, hintMap)
-        case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond), _))
-          if projectList.forall(_.isInstanceOf[Attribute]) =>
-          reorder(p, p.output, hintMap)
+        // Avoid reordering if a join hint is present.
+        case j @ Join(_, _, _: InnerLike, Some(cond), hint) if hint == JoinHint.NONE =>
+          reorder(j, j.output)
+        case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond), hint))
+          if projectList.forall(_.isInstanceOf[Attribute]) && hint == JoinHint.NONE =>
+          reorder(p, p.output)
       }
       // After reordering is finished, convert OrderedJoin back to Join.
       result transform {
-        case OrderedJoin(left, right, jt, cond) =>
-          val joinHint = JoinHint(hintMap.get(left.outputSet), hintMap.get(right.outputSet))
-          Join(left, right, jt, cond, joinHint)
+        case OrderedJoin(left, right, jt, cond) => Join(left, right, jt, cond, JoinHint.NONE)
       }
     }
   }
 
-  private def reorder(
-      plan: LogicalPlan,
-      output: Seq[Attribute],
-      hintMap: mutable.HashMap[AttributeSet, HintInfo]): LogicalPlan = {
-    val (items, conditions) = extractInnerJoins(plan, hintMap)
+  private def reorder(plan: LogicalPlan, output: Seq[Attribute]): LogicalPlan = {
+    val (items, conditions) = extractInnerJoins(plan)
     val result =
       // Do reordering if the number of items is appropriate and join conditions exist.
       // We also need to check if costs of all items can be evaluated.
@@ -115,20 +75,16 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
    * Extracts items of consecutive inner joins and join conditions.
    * This method works for bushy trees and left/right deep trees.
    */
-  private def extractInnerJoins(
-      plan: LogicalPlan,
-      hintMap: mutable.HashMap[AttributeSet, HintInfo]): (Seq[LogicalPlan], Set[Expression]) = {
+  private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], Set[Expression]) = {
     plan match {
-      case Join(left, right, _: InnerLike, Some(cond), hint) =>
-        hint.leftHint.foreach(hintMap.put(left.outputSet, _))
-        hint.rightHint.foreach(hintMap.put(right.outputSet, _))
-        val (leftPlans, leftConditions) = extractInnerJoins(left, hintMap)
-        val (rightPlans, rightConditions) = extractInnerJoins(right, hintMap)
+      case Join(left, right, _: InnerLike, Some(cond), _) =>
+        val (leftPlans, leftConditions) = extractInnerJoins(left)
+        val (rightPlans, rightConditions) = extractInnerJoins(right)
         (leftPlans ++ rightPlans, splitConjunctivePredicates(cond).toSet ++
           leftConditions ++ rightConditions)
       case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), _))
         if projectList.forall(_.isInstanceOf[Attribute]) =>
-        extractInnerJoins(j, hintMap)
+        extractInnerJoins(j)
       case _ =>
         (Seq(plan), Set())
     }
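
The hunks above reduce hint handling to a single rule: a join subtree enters cost-based
reordering only when its top-level JoinHint is empty, and OrderedJoin nodes can be converted
back with JoinHint.NONE because no hinted join ever reaches that point. A minimal sketch of
the entry guard, assuming only the Catalyst shapes visible in this diff (the helper name
isReorderCandidate is hypothetical):

    // Sketch of the CBO entry-point guard after this change. JoinHint.NONE
    // is JoinHint(None, None); any other hint disqualifies the join.
    import org.apache.spark.sql.catalyst.plans.InnerLike
    import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint, LogicalPlan}

    def isReorderCandidate(plan: LogicalPlan): Boolean = plan match {
      // Only an inner-like join with a condition and no hint starts reordering.
      case Join(_, _, _: InnerLike, Some(_), hint) => hint == JoinHint.NONE
      case _ => false
    }
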
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index 82aefca8a1af6..251ece315f6a8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -43,13 +43,11 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
    *
    * @param input a list of LogicalPlans to inner join and the type of inner join.
    * @param conditions a list of condition for join.
-   * @param hintMap a map of relation output attribute sets to their corresponding hints.
    */
   @tailrec
   final def createOrderedJoin(
       input: Seq[(LogicalPlan, InnerLike)],
-      conditions: Seq[Expression],
-      hintMap: Map[AttributeSet, HintInfo]): LogicalPlan = {
+      conditions: Seq[Expression]): LogicalPlan = {
     assert(input.size >= 2)
     if (input.size == 2) {
       val (joinConditions, others) = conditions.partition(canEvaluateWithinJoin)
@@ -58,8 +56,8 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
         case (Inner, Inner) => Inner
         case (_, _) => Cross
       }
-      val join = Join(left, right, innerJoinType, joinConditions.reduceLeftOption(And),
-        JoinHint(hintMap.get(left.outputSet), hintMap.get(right.outputSet)))
+      val join = Join(left, right, innerJoinType,
+        joinConditions.reduceLeftOption(And), JoinHint.NONE)
       if (others.nonEmpty) {
         Filter(others.reduceLeft(And), join)
       } else {
@@ -82,27 +80,27 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
       val joinedRefs = left.outputSet ++ right.outputSet
       val (joinConditions, others) = conditions.partition(
         e => e.references.subsetOf(joinedRefs) && canEvaluateWithinJoin(e))
-      val joined = Join(left, right, innerJoinType, joinConditions.reduceLeftOption(And),
-        JoinHint(hintMap.get(left.outputSet), hintMap.get(right.outputSet)))
+      val joined = Join(left, right, innerJoinType,
+        joinConditions.reduceLeftOption(And), JoinHint.NONE)
 
       // should not have reference to same logical plan
-      createOrderedJoin(Seq((joined, Inner)) ++ rest.filterNot(_._1 eq right), others, hintMap)
+      createOrderedJoin(Seq((joined, Inner)) ++ rest.filterNot(_._1 eq right), others)
     }
   }
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case p @ ExtractFiltersAndInnerJoins(input, conditions, hintMap)
+    case p @ ExtractFiltersAndInnerJoins(input, conditions)
         if input.size > 2 && conditions.nonEmpty =>
       val reordered = if (SQLConf.get.starSchemaDetection && !SQLConf.get.cboEnabled) {
         val starJoinPlan = StarSchemaDetection.reorderStarJoins(input, conditions)
         if (starJoinPlan.nonEmpty) {
           val rest = input.filterNot(starJoinPlan.contains(_))
-          createOrderedJoin(starJoinPlan ++ rest, conditions, hintMap)
+          createOrderedJoin(starJoinPlan ++ rest, conditions)
         } else {
-          createOrderedJoin(input, conditions, hintMap)
+          createOrderedJoin(input, conditions)
         }
       } else {
-        createOrderedJoin(input, conditions, hintMap)
+        createOrderedJoin(input, conditions)
       }
 
       if (p.sameOutput(reordered)) {
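
Because a hinted join can no longer match ExtractFiltersAndInnerJoins, createOrderedJoin
never sees one, so stamping JoinHint.NONE on every join it builds loses nothing. A sketch of
the applicability check that ReorderJoin.apply now performs, reusing the two-field extractor
defined in patterns.scala below (the helper name wouldReorder is hypothetical):

    // Mirrors the guard in ReorderJoin.apply: the extractor fires only when no
    // join on the flattened spine carries a hint, and reordering is attempted
    // only for more than two items with at least one join condition.
    import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

    def wouldReorder(plan: LogicalPlan): Boolean = plan match {
      case ExtractFiltersAndInnerJoins(input, conditions) =>
        input.size > 2 && conditions.nonEmpty
      case _ => false
    }
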
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 95be0a52cb2ed..a816922f49aee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -166,35 +166,27 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper {
    * was involved in an explicit cross join. Also returns the entire list of join conditions for
    * the left-deep tree.
    */
-  def flattenJoin(
-      plan: LogicalPlan,
-      hintMap: mutable.HashMap[AttributeSet, HintInfo],
-      parentJoinType: InnerLike = Inner)
+  def flattenJoin(plan: LogicalPlan, parentJoinType: InnerLike = Inner)
       : (Seq[(LogicalPlan, InnerLike)], Seq[Expression]) = plan match {
-    case Join(left, right, joinType: InnerLike, cond, hint) =>
-      val (plans, conditions) = flattenJoin(left, hintMap, joinType)
-      hint.leftHint.map(hintMap.put(left.outputSet, _))
-      hint.rightHint.map(hintMap.put(right.outputSet, _))
+    case Join(left, right, joinType: InnerLike, cond, hint) if hint == JoinHint.NONE =>
+      val (plans, conditions) = flattenJoin(left, joinType)
       (plans ++ Seq((right, joinType)), conditions ++
         cond.toSeq.flatMap(splitConjunctivePredicates))
-    case Filter(filterCondition, j @ Join(_, _, _: InnerLike, _, _)) =>
-      val (plans, conditions) = flattenJoin(j, hintMap)
+    case Filter(filterCondition, j @ Join(_, _, _: InnerLike, _, hint)) if hint == JoinHint.NONE =>
+      val (plans, conditions) = flattenJoin(j)
       (plans, conditions ++ splitConjunctivePredicates(filterCondition))
 
     case _ => (Seq((plan, parentJoinType)), Seq.empty)
   }
 
   def unapply(plan: LogicalPlan)
-      : Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression], Map[AttributeSet, HintInfo])]
+      : Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])]
       = plan match {
-    case f @ Filter(filterCondition, j @ Join(_, _, joinType: InnerLike, _, _)) =>
-      val hintMap = new mutable.HashMap[AttributeSet, HintInfo]
-      val flattened = flattenJoin(f, hintMap)
-      Some((flattened._1, flattened._2, hintMap.toMap))
-    case j @ Join(_, _, joinType, _, _) =>
-      val hintMap = new mutable.HashMap[AttributeSet, HintInfo]
-      val flattened = flattenJoin(j, hintMap)
-      Some((flattened._1, flattened._2, hintMap.toMap))
+    case f @ Filter(filterCondition, j @ Join(_, _, joinType: InnerLike, _, hint))
+        if hint == JoinHint.NONE =>
+      Some(flattenJoin(f))
+    case j @ Join(_, _, joinType, _, hint) if hint == JoinHint.NONE =>
+      Some(flattenJoin(j))
     case _ => None
   }
 }
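
The upshot of the extractor change: one non-empty JoinHint anywhere on the spine of an
inner-join tree makes unapply decline the whole match, so ReorderJoin leaves the tree
untouched. A small self-contained sketch built directly from Catalyst nodes (the relations
and column names are made up for illustration):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
    import org.apache.spark.sql.catalyst.plans.Inner
    import org.apache.spark.sql.catalyst.plans.logical._

    val a = LocalRelation('a1.int)
    val b = LocalRelation('b1.int)
    val c = LocalRelation('c1.int)

    // An unhinted inner join still flattens normally...
    val plain = Join(a, b, Inner, Some('a1.attr === 'b1.attr), JoinHint.NONE)
    assert(ExtractFiltersAndInnerJoins.unapply(plain).isDefined)

    // ...but a broadcast hint on the spine makes unapply return None.
    val hinted = Join(plain, c, Inner, Some('b1.attr === 'c1.attr),
      JoinHint(Some(HintInfo(broadcast = true)), None))
    assert(ExtractFiltersAndInnerJoins.unapply(hinted).isEmpty)
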
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
index 9093d7fecb0f7..c570643c74106 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
@@ -66,7 +66,7 @@ class JoinOptimizationSuite extends PlanTest {
     def testExtractCheckCross
         (plan: LogicalPlan, expected: Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])]) {
       assert(
-        ExtractFiltersAndInnerJoins.unapply(plan) === expected.map(e => (e._1, e._2, Map.empty)))
+        ExtractFiltersAndInnerJoins.unapply(plan) === expected.map(e => (e._1, e._2)))
     }
 
     testExtract(x, None)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
index 0dee846205868..f1da0a8e865b0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
@@ -292,77 +292,56 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
     assertEqualPlans(originalPlan, bestPlan)
   }
 
-  test("hints preservation") {
-    // Apply hints if we find an equivalent node in the new plan, otherwise discard them.
+  test("don't reorder if hints present") {
     val originalPlan =
-      t1.join(t2.hint("broadcast")).hint("broadcast").join(t4.join(t3).hint("broadcast"))
-        .where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
-          (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
-          (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
-
-    val bestPlan =
-    t1.join(t2.hint("broadcast"), Inner, Some(nameToAttr("t1.k-1-2") === 
nameToAttr("t2.k-1-5")))
-      .hint("broadcast")
-      .join(
-        t4.join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
-          .hint("broadcast"),
-        Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
+      t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
+        .hint("broadcast")
+        .join(
+          t4.join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
+            .hint("broadcast"),
+          Inner,
+          Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
 
-    assertEqualPlans(originalPlan, bestPlan)
+    assertEqualPlans(originalPlan, originalPlan)
 
     val originalPlan2 =
-      t1.join(t2).hint("broadcast").join(t3).hint("broadcast").join(t4.hint("broadcast"))
-        .where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
-          (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
-          (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
-
-    val bestPlan2 =
       t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
         .hint("broadcast")
-        .join(
-          t4.hint("broadcast")
-            .join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))),
-          Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
-        .select(outputsOf(t1, t2, t3, t4): _*)
+        .join(t4, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
+        .hint("broadcast")
+        .join(t3, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
 
-    assertEqualPlans(originalPlan2, bestPlan2)
+    assertEqualPlans(originalPlan2, originalPlan2)
+  }
 
-    val originalPlan3 =
-      t1.join(t4).hint("broadcast")
-        .join(t2.hint("broadcast")).hint("broadcast")
-        .join(t3.hint("broadcast"))
-        .where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
-        (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
-        (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
+  test("reorder below and above the hint node") {
+    val originalPlan =
+      t1.join(t2).join(t3)
+        .where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
+          (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
+        .hint("broadcast").join(t4)
 
-    val bestPlan3 =
-      t1.join(t2.hint("broadcast"), Inner, Some(nameToAttr("t1.k-1-2") === 
nameToAttr("t2.k-1-5")))
-        .join(
-          t4.join(t3.hint("broadcast"),
-            Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))),
-          Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
-        .select(outputsOf(t1, t4, t2, t3): _*)
+    val bestPlan =
+    t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
+      .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
+      .select(outputsOf(t1, t2, t3): _*)
+      .hint("broadcast").join(t4)
 
-    assertEqualPlans(originalPlan3, bestPlan3)
+    assertEqualPlans(originalPlan, bestPlan)
 
-    val originalPlan4 =
-      t2.hint("broadcast")
-        .join(t4).hint("broadcast")
-        .join(t3.hint("broadcast")).hint("broadcast")
-        .join(t1)
-        .where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
-          (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
-          (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
+    val originalPlan2 =
+      t1.join(t2).join(t3)
+        .where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
+          (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
+        .join(t4.hint("broadcast"))
 
-    val bestPlan4 =
-      t1.join(t2.hint("broadcast"), Inner, Some(nameToAttr("t1.k-1-2") === 
nameToAttr("t2.k-1-5")))
-        .join(
-          t4.join(t3.hint("broadcast"),
-            Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))),
-          Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
-        .select(outputsOf(t2, t4, t3, t1): _*)
+    val bestPlan2 =
+      t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
+        .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
+        .select(outputsOf(t1, t2, t3): _*)
+        .join(t4.hint("broadcast"))
 
-    assertEqualPlans(originalPlan4, bestPlan4)
+    assertEqualPlans(originalPlan2, bestPlan2)
   }
 
   private def assertEqualPlans(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
index 55f210cb04dbf..30a3d54fd833f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
 
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 
 class JoinHintSuite extends PlanTest with SharedSQLContext {
@@ -100,7 +101,7 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
     }
   }
 
-  test("hint preserved after join reorder") {
+  test("hints prevent join reorder") {
     withTempView("a", "b", "c") {
       df1.createOrReplaceTempView("a")
       df2.createOrReplaceTempView("b")
@@ -118,12 +119,10 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
       verifyJoinHint(
         sql("select /*+ broadcast(a, c)*/ * from a, c, b " +
           "where a.a1 = b.b1 and b.b1 = c.c1"),
-        JoinHint(
-          None,
-          Some(HintInfo(broadcast = true))) ::
+        JoinHint.NONE ::
           JoinHint(
             Some(HintInfo(broadcast = true)),
-            None):: Nil
+            Some(HintInfo(broadcast = true))):: Nil
       )
       verifyJoinHint(
         sql("select /*+ broadcast(b, c)*/ * from a, c, b " +
@@ -199,4 +198,21 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
         None) :: Nil
     )
   }
+
+  test("hints prevent cost-based join reorder") {
+    withSQLConf(SQLConf.CBO_ENABLED.key -> "true", SQLConf.JOIN_REORDER_ENABLED.key -> "true") {
+      val join = df.join(df, "id")
+      val broadcasted = join.hint("broadcast")
+      verifyJoinHint(
+        join.join(broadcasted, "id").join(broadcasted, "id"),
+        JoinHint(
+          None,
+          Some(HintInfo(broadcast = true))) ::
+          JoinHint(
+            None,
+            Some(HintInfo(broadcast = true))) ::
+          JoinHint.NONE :: JoinHint.NONE :: JoinHint.NONE :: Nil
+      )
+    }
+  }
 }
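
The new suite exercises the same barrier end to end. For a user-level picture, a sketch with
the public DataFrame API (the session setup, sample data, and column names are made up;
cost-based reordering additionally requires table statistics to kick in):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    // Both reorder rules are gated on these configs plus the absence of hints.
    spark.conf.set("spark.sql.cbo.enabled", "true")
    spark.conf.set("spark.sql.cbo.joinReorder.enabled", "true")

    val a = Seq((1, "x")).toDF("a1", "av")
    val b = Seq((1, "y")).toDF("b1", "bv")
    val c = Seq((1, "z")).toDF("c1", "cv")

    // Hinting b attaches a JoinHint to the (a JOIN b) node, which now freezes
    // the join order at and above that node instead of being re-applied after
    // a reorder.
    val hinted = a.join(b.hint("broadcast"), $"a1" === $"b1")
      .join(c, $"b1" === $"c1")
    hinted.explain(true)
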

