Repository: spark
Updated Branches:
  refs/heads/master f14f81e90 -> 816391159


[SPARK-17791][SQL] Join reordering using star schema detection

## What changes were proposed in this pull request?

A star schema consists of one or more fact tables referencing a number of
dimension tables. In general, queries against a star schema are expected to run
fast because of the established referential integrity (RI) constraints among the
tables. This design proposes a join reordering based on natural, generally
accepted heuristics for star schema queries:
- Finds the star join with the largest fact table and places it on the driving
arm of the left-deep join tree. This plan avoids large tables on the inner side
of the joins and thus favors hash joins.
- Applies the most selective dimensions early in the plan to reduce the amount
of data flowing through the plan.

The design document is attached to SPARK-17791.

Link to the Google Doc:
[StarSchemaDetection](https://docs.google.com/document/d/1UAfwbm_A6wo7goHlVZfYK99pqDMEZUumi7pubJXETEA/edit?usp=sharing)

## How was this patch tested?

A new test suite, StarJoinReorderSuite.scala, was added.

Author: Ioana Delaney <ioanamdela...@gmail.com>

Closes #15363 from ioana-delaney/starJoinReord2.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/81639115
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/81639115
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/81639115

Branch: refs/heads/master
Commit: 81639115947a13017d1637549a8f66ba599b27b8
Parents: f14f81e
Author: Ioana Delaney <ioanamdela...@gmail.com>
Authored: Mon Mar 20 16:04:58 2017 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Mon Mar 20 16:04:58 2017 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/SimpleCatalystConf.scala |   1 +
 .../optimizer/CostBasedJoinReorder.scala        |   2 +
 .../sql/catalyst/optimizer/Optimizer.scala      |   2 +-
 .../spark/sql/catalyst/optimizer/joins.scala    | 350 ++++++++++-
 .../spark/sql/catalyst/planning/patterns.scala  |   4 +-
 .../org/apache/spark/sql/internal/SQLConf.scala |  16 +
 .../optimizer/JoinOptimizationSuite.scala       |   4 +-
 .../catalyst/optimizer/JoinReorderSuite.scala   |  29 +-
 .../optimizer/StarJoinReorderSuite.scala        | 580 +++++++++++++++++++
 .../spark/sql/catalyst/plans/PlanTest.scala     |  26 +
 10 files changed, 978 insertions(+), 36 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/81639115/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala
index 0d4903e..ac97987 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala
@@ -40,6 +40,7 @@ case class SimpleCatalystConf(
     override val cboEnabled: Boolean = false,
     override val joinReorderEnabled: Boolean = false,
     override val joinReorderDPThreshold: Int = 12,
+    override val starSchemaDetection: Boolean = false,
     override val warehousePath: String = "/user/hive/warehouse",
     override val sessionLocalTimeZone: String = TimeZone.getDefault().getID,
     override val maxNestedViewDepth: Int = 100)

http://git-wip-us.apache.org/repos/asf/spark/blob/81639115/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
index 1b32bda..521c468 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
@@ -53,6 +53,8 @@ case class CostBasedJoinReorder(conf: SQLConf) extends 
Rule[LogicalPlan] with Pr
 
   def reorder(plan: LogicalPlan, output: AttributeSet): LogicalPlan = {
     val (items, conditions) = extractInnerJoins(plan)
+    // TODO: Compute the set of star-joins and use them in the join enumeration
+    // algorithm to prune un-optimal plan choices.
     val result =
       // Do reordering if the number of items is appropriate and join 
conditions exist.
       // We also need to check if costs of all items can be evaluated.

http://git-wip-us.apache.org/repos/asf/spark/blob/81639115/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index c8ed419..d7524a5 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -82,7 +82,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, 
conf: CatalystConf)
     Batch("Operator Optimizations", fixedPoint,
       // Operator push down
       PushProjectionThroughUnion,
-      ReorderJoin,
+      ReorderJoin(conf),
       EliminateOuterJoin,
       PushPredicateThroughJoin,
       PushDownPredicate,

http://git-wip-us.apache.org/repos/asf/spark/blob/81639115/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index bfe529e..58e4a23 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -20,19 +20,347 @@ package org.apache.spark.sql.catalyst.optimizer
 import scala.annotation.tailrec
 
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
+import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, 
PhysicalOperation}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Encapsulates star-schema join detection and the heuristic reordering of a detected
+ * star join. Detection only runs when `conf.starSchemaDetection` is set, and it relies on
+ * the table and column statistics of base table accesses.
+ */
+case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
+
+  /**
+   * Star schema consists of one or more fact tables referencing a number of dimension
+   * tables. In general, star-schema joins are detected using the following conditions:
+   *  1. Informational referential integrity (RI) constraints (reliable detection)
+   *    + Dimension contains a primary key that is being joined to the fact table.
+   *    + Fact table contains foreign keys referencing multiple dimension tables.
+   *  2. Cardinality based heuristics
+   *    + Usually, the table with the highest cardinality is the fact table.
+   *    + Table being joined with the most number of tables is the fact table.
+   *
+   * To detect star joins, the algorithm uses a combination of the above two conditions.
+   * The fact table is chosen based on the cardinality heuristics, and the dimension
+   * tables are chosen based on the RI constraints. A star join will consist of the largest
+   * fact table joined with the dimension tables on their primary keys. To detect that a
+   * column is a primary key, the algorithm uses table and column statistics.
+   *
+   * Since Catalyst only supports left-deep tree plans, the algorithm currently returns
+   * only the star join with the largest fact table. Choosing the largest fact table on the
+   * driving arm to avoid large inners is in general a good heuristic. This restriction can
+   * be lifted with support for bushy tree plans.
+   *
+   * The highlights of the algorithm are the following:
+   *
+   * Given a set of joined tables/plans, the algorithm first verifies if they are eligible
+   * for star join detection. An eligible plan is a base table access with valid
+   * statistics. A base table access represents Project or Filter operators above a
+   * LeafNode. Conservatively, the algorithm only considers base table access as part of a
+   * star join since they provide reliable statistics.
+   *
+   * If some of the plans are not base table access, or statistics are not available, the
+   * algorithm returns an empty star join plan since, in the absence of statistics, it
+   * cannot make good planning decisions. Otherwise, the algorithm finds the table with the
+   * largest cardinality (number of rows), which is assumed to be a fact table.
+   *
+   * Next, it computes the set of dimension tables for the current fact table. A dimension
+   * table is assumed to be in a RI relationship with a fact table. To infer column
+   * uniqueness, the algorithm compares the number of distinct values with the total number
+   * of rows in the table. If their relative difference is within certain limits (i.e.
+   * ndvMaxError * 2, adjusted based on 1TB TPC-DS data), the column is assumed to be
+   * unique.
+   *
+   * @param input the join items to examine
+   * @param conditions the join conditions among the input plans
+   * @return at most one star join, represented as the fact table followed by its eligible
+   *         dimension tables; empty when detection is disabled, fewer than two inputs are
+   *         given, statistics are missing, the two largest tables are of comparable size
+   *         (see `conf.starSchemaFTRatio`), or no unique-key equi-join dimension is found
+   */
+  def findStarJoins(
+      input: Seq[LogicalPlan],
+      conditions: Seq[Expression]): Seq[Seq[LogicalPlan]] = {
+
+    val emptyStarJoinPlan = Seq.empty[Seq[LogicalPlan]]
+
+    if (!conf.starSchemaDetection || input.size < 2) {
+      emptyStarJoinPlan
+    } else {
+      // Find if the input plans are eligible for star join detection.
+      // An eligible plan is a base table access with valid statistics.
+      val foundEligibleJoin = input.forall {
+        case PhysicalOperation(_, _, t: LeafNode) if t.stats(conf).rowCount.isDefined => true
+        case _ => false
+      }
+
+      if (!foundEligibleJoin) {
+        // Some plans don't have stats or are complex plans. Conservatively,
+        // return an empty star join. This restriction can be lifted
+        // once statistics are propagated in the plan.
+        emptyStarJoinPlan
+      } else {
+        // Find the fact table using cardinality based heuristics i.e.
+        // the table with the largest number of rows. The result is converted to a List
+        // so that the `::` patterns below match whatever Seq implementation the caller
+        // passed in (they would otherwise throw a MatchError on e.g. a Vector).
+        val sortedFactTables = input.map { plan =>
+          TableAccessCardinality(plan, getTableAccessCardinality(plan))
+        }.collect { case t @ TableAccessCardinality(_, Some(_)) =>
+          t
+        }.sortBy(_.size)(implicitly[Ordering[Option[BigInt]]].reverse).toList
+
+        sortedFactTables match {
+          case Nil =>
+            emptyStarJoinPlan
+          case table1 :: table2 :: _
+            if table2.size.get.toDouble > conf.starSchemaFTRatio * table1.size.get.toDouble =>
+            // If the top largest tables have comparable number of rows, return an empty
+            // star plan. This restriction will be lifted when the algorithm is generalized
+            // to return multiple star plans.
+            emptyStarJoinPlan
+          case TableAccessCardinality(factTable, _) :: rest =>
+            // Find the fact table joins.
+            val allFactJoins = rest.collect {
+              case TableAccessCardinality(plan, _)
+                  if findJoinConditions(factTable, plan, conditions).nonEmpty =>
+                plan
+            }
+
+            // Find the corresponding join conditions.
+            val allFactJoinCond = allFactJoins.flatMap { plan =>
+              findJoinConditions(factTable, plan, conditions)
+            }
+
+            // Verify if the join columns have valid statistics.
+            // Allow any relational comparison between the tables. Later
+            // we will heuristically choose a subset of equi-join
+            // tables.
+            val areStatsAvailable = allFactJoins.forall { dimTable =>
+              allFactJoinCond.exists {
+                case BinaryComparison(lhs: AttributeReference, rhs: AttributeReference) =>
+                  val dimCol = if (dimTable.outputSet.contains(lhs)) lhs else rhs
+                  val factCol = if (factTable.outputSet.contains(lhs)) lhs else rhs
+                  hasStatistics(dimCol, dimTable) && hasStatistics(factCol, factTable)
+                case _ => false
+              }
+            }
+
+            if (!areStatsAvailable) {
+              emptyStarJoinPlan
+            } else {
+              // Find the subset of dimension tables. A dimension table is assumed to be in
+              // a RI relationship with the fact table. Only consider equi-joins
+              // between a fact and a dimension table to avoid expanding joins.
+              val eligibleDimPlans = allFactJoins.filter { dimTable =>
+                allFactJoinCond.exists {
+                  case Equality(lhs: AttributeReference, rhs: AttributeReference) =>
+                    val dimCol = if (dimTable.outputSet.contains(lhs)) lhs else rhs
+                    isUnique(dimCol, dimTable)
+                  case _ => false
+                }
+              }
+
+              if (eligibleDimPlans.isEmpty) {
+                // An eligible star join was not found because the join is not
+                // an RI join, or the star join is an expanding join.
+                emptyStarJoinPlan
+              } else {
+                Seq(factTable +: eligibleDimPlans)
+              }
+            }
+        }
+      }
+    }
+  }
+
+  /**
+   * Reorders a star join based on heuristics:
+   *   1) Finds the star join with the largest fact table and places it on the driving
+   *      arm of the left-deep tree. This plan avoids large table access on the inner,
+   *      and thus favors hash joins.
+   *   2) Applies the most selective dimensions early in the plan to reduce the amount of
+   *      data flow.
+   *
+   * @param input the join items; only items whose join type is `Inner` are considered
+   * @param conditions the join conditions
+   * @return the fact table followed by its dimension tables ordered by ascending
+   *         cardinality, each paired with `Inner`; empty when no star join is found, it
+   *         has fewer than two dimensions, or it is not selective
+   */
+  def reorderStarJoins(
+      input: Seq[(LogicalPlan, InnerLike)],
+      conditions: Seq[Expression]): Seq[(LogicalPlan, InnerLike)] = {
+    assert(input.size >= 2)
+
+    val emptyStarJoinPlan = Seq.empty[(LogicalPlan, InnerLike)]
+
+    // Find the eligible star plans. Currently, it only returns
+    // the star join with the largest fact table.
+    val eligibleJoins = input.collect { case (plan, Inner) => plan }
+    val starPlans = findStarJoins(eligibleJoins, conditions)
+
+    if (starPlans.isEmpty) {
+      emptyStarJoinPlan
+    } else {
+      val starPlan = starPlans.head
+      val (factTable, dimTables) = (starPlan.head, starPlan.tail)
+
+      // Only consider selective joins. This case is detected by observing local
+      // predicates on the dimension tables. In a star schema relationship, the join
+      // between the fact and the dimension table is a FK-PK join. Heuristically, a
+      // selective dimension may reduce the result of a join.
+      // Also, conservatively assume that a fact table is joined with more than one
+      // dimension.
+      if (dimTables.size >= 2 && isSelectiveStarJoin(dimTables, conditions)) {
+        val reorderDimTables = dimTables.map { plan =>
+          TableAccessCardinality(plan, getTableAccessCardinality(plan))
+        }.sortBy(_.size).map {
+          case TableAccessCardinality(p1, _) => p1
+        }
+
+        val reorderStarPlan = factTable +: reorderDimTables
+        reorderStarPlan.map(plan => (plan, Inner))
+      } else {
+        emptyStarJoinPlan
+      }
+    }
+  }
+
+  /**
+   * Determines if a column referenced by a base table access is a primary key.
+   * A column is treated as a PK if its statistics report no null values and it has
+   * unique values. To determine if a column has unique values in the absence of
+   * informational RI constraints, the number of distinct values is compared to the total
+   * number of rows in the table. If their relative difference is within the expected
+   * limits (i.e. 2 * spark.sql.statistics.ndv.maxError based on TPCDS data results), the
+   * column is assumed to have unique values.
+   *
+   * Returns false (rather than failing with a MatchError) when the column cannot be
+   * traced to the leaf node, or when the row count or column statistics are missing.
+   */
+  private def isUnique(
+      column: Attribute,
+      plan: LogicalPlan): Boolean = plan match {
+    case PhysicalOperation(_, _, t: LeafNode) =>
+      findLeafNodeCol(column, plan) match {
+        case Some(col) if t.outputSet.contains(col) =>
+          val stats = t.stats(conf)
+          (stats.rowCount, stats.attributeStats.get(col)) match {
+            case (Some(rowCount), Some(colStats)) if rowCount >= 0 =>
+              if (colStats.nullCount > 0) {
+                false
+              } else {
+                val distinctCount = colStats.distinctCount
+                val relDiff = math.abs((distinctCount.toDouble / rowCount.toDouble) - 1.0d)
+                // ndvMaxErr adjusted based on TPCDS 1TB data results
+                relDiff <= conf.ndvMaxError * 2
+              }
+            // Missing row count, missing column statistics, or a negative row count.
+            case _ => false
+          }
+        // The column could not be traced back to the leaf node's output.
+        case _ => false
+      }
+    case _ => false
+  }
+
+  /**
+   * Given a column over a base table access, it returns
+   * the leaf node column from which the input column is derived, or None when the column
+   * is not a pass-through of a leaf node attribute.
+   */
+  @tailrec
+  private def findLeafNodeCol(
+      column: Attribute,
+      plan: LogicalPlan): Option[Attribute] = plan match {
+    case pl @ PhysicalOperation(_, _, _: LeafNode) =>
+      pl match {
+        case t: LeafNode if t.outputSet.contains(column) =>
+          Option(column)
+        case p: Project if p.outputSet.exists(_.semanticEquals(column)) =>
+          val col = p.outputSet.find(_.semanticEquals(column)).get
+          findLeafNodeCol(col, p.child)
+        case f: Filter =>
+          findLeafNodeCol(column, f.child)
+        case _ => None
+      }
+    case _ => None
+  }
+
+  /**
+   * Checks if a column has statistics.
+   * The column is assumed to be over a base table access. Returns false when the column
+   * cannot be traced back to the leaf node's output.
+   */
+  private def hasStatistics(
+      column: Attribute,
+      plan: LogicalPlan): Boolean = plan match {
+    case PhysicalOperation(_, _, t: LeafNode) =>
+      findLeafNodeCol(column, plan) match {
+        case Some(col) if t.outputSet.contains(col) =>
+          t.stats(conf).attributeStats.contains(col)
+        case _ => false
+      }
+    case _ => false
+  }
+
+  /**
+   * Returns the join predicates between two input plans. It only considers basic
+   * comparison operators that reference only the two plans' output and that cannot be
+   * evaluated on either plan alone.
+   */
+  @inline
+  private def findJoinConditions(
+      plan1: LogicalPlan,
+      plan2: LogicalPlan,
+      conditions: Seq[Expression]): Seq[Expression] = {
+    val refs = plan1.outputSet ++ plan2.outputSet
+    conditions.filter {
+      case BinaryComparison(_, _) => true
+      case _ => false
+    }.filterNot(canEvaluate(_, plan1))
+     .filterNot(canEvaluate(_, plan2))
+     .filter(_.references.subsetOf(refs))
+  }
+
+  /**
+   * Checks if a star join is a selective join. A star join is assumed
+   * to be selective if there are local predicates (other than IsNotNull) on the
+   * dimension tables.
+   */
+  private def isSelectiveStarJoin(
+      dimTables: Seq[LogicalPlan],
+      conditions: Seq[Expression]): Boolean = dimTables.exists {
+    case plan @ PhysicalOperation(_, p, _: LeafNode) =>
+      // Checks if any condition applies to the dimension tables.
+      // Exclude the IsNotNull predicates until predicate selectivity is available.
+      // In most cases, this predicate is artificially introduced by the Optimizer
+      // to enforce nullability constraints.
+      val localPredicates = conditions.filterNot(_.isInstanceOf[IsNotNull])
+        .exists(canEvaluate(_, plan))
+
+      // Checks if there are any non-IsNotNull predicates pushed down to the base table
+      // access.
+      val pushedDownPredicates = p.exists(!_.isInstanceOf[IsNotNull])
+
+      localPredicates || pushedDownPredicates
+    case _ => false
+  }
+
+  /**
+   * Helper case class to hold (plan, rowCount) pairs.
+   */
+  private case class TableAccessCardinality(plan: LogicalPlan, size: Option[BigInt])
+
+  /**
+   * Returns the cardinality of a base table access. A base table access represents
+   * a LeafNode, or Project or Filter operators above a LeafNode.
+   *
+   * When CBO is enabled and the whole base table access has an estimated row count, that
+   * estimate is returned; otherwise the leaf node's row count is returned. Returns None
+   * for plans that are not base table accesses or whose leaf node has no row count.
+   */
+  private def getTableAccessCardinality(
+      input: LogicalPlan): Option[BigInt] = input match {
+    case PhysicalOperation(_, _, t: LeafNode) if t.stats(conf).rowCount.isDefined =>
+      val leafRowCount = t.stats(conf).rowCount
+      if (conf.cboEnabled) input.stats(conf).rowCount.orElse(leafRowCount) else leafRowCount
+    case _ => None
+  }
+}
 
 /**
  * Reorder the joins and push all the conditions into join, so that the bottom 
ones have at least
  * one condition.
  *
  * The order of joins will not be changed if all of them already have at least 
one condition.
+ *
+ * If star schema detection is enabled, reorder the star join plans based on 
heuristics.
  */
-object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
-
+case class ReorderJoin(conf: SQLConf) extends Rule[LogicalPlan] with 
PredicateHelper {
   /**
    * Join a list of plans together and push down the conditions into them.
    *
@@ -42,7 +370,7 @@ object ReorderJoin extends Rule[LogicalPlan] with 
PredicateHelper {
    * @param conditions a list of condition for join.
    */
   @tailrec
-  def createOrderedJoin(input: Seq[(LogicalPlan, InnerLike)], conditions: 
Seq[Expression])
+  final def createOrderedJoin(input: Seq[(LogicalPlan, InnerLike)], 
conditions: Seq[Expression])
     : LogicalPlan = {
     assert(input.size >= 2)
     if (input.size == 2) {
@@ -83,9 +411,19 @@ object ReorderJoin extends Rule[LogicalPlan] with 
PredicateHelper {
   }
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case j @ ExtractFiltersAndInnerJoins(input, conditions)
+    case ExtractFiltersAndInnerJoins(input, conditions)
         if input.size > 2 && conditions.nonEmpty =>
-      createOrderedJoin(input, conditions)
+      // Star-schema reordering is only attempted when it is enabled and CBO is disabled;
+      // an empty result means "no star join found, keep the original item order".
+      val starJoinPlan =
+        if (conf.starSchemaDetection && !conf.cboEnabled) {
+          StarSchemaDetection(conf).reorderStarJoins(input, conditions)
+        } else {
+          Seq.empty[(LogicalPlan, InnerLike)]
+        }
+      if (starJoinPlan.isEmpty) {
+        createOrderedJoin(input, conditions)
+      } else {
+        // Put the reordered star join first, followed by every item not part of it.
+        val rest = input.filterNot(starJoinPlan.contains(_))
+        createOrderedJoin(starJoinPlan ++ rest, conditions)
+      }
     }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/81639115/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 0893af2..d39b0ef 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -167,8 +167,8 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper {
       : (Seq[(LogicalPlan, InnerLike)], Seq[Expression]) = plan match {
     case Join(left, right, joinType: InnerLike, cond) =>
       val (plans, conditions) = flattenJoin(left, joinType)
-      (plans ++ Seq((right, joinType)), conditions ++ cond.toSeq)
-
+      (plans ++ Seq((right, joinType)), conditions ++
+        cond.toSeq.flatMap(splitConjunctivePredicates))
     case Filter(filterCondition, j @ Join(left, right, _: InnerLike, 
joinCondition)) =>
       val (plans, conditions) = flattenJoin(j)
       (plans, conditions ++ splitConjunctivePredicates(filterCondition))

http://git-wip-us.apache.org/repos/asf/spark/blob/81639115/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index d2ac4b8..b6e0b8c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -719,6 +719,18 @@ object SQLConf {
       .checkValue(weight => weight >= 0 && weight <= 1, "The weight value must 
be in [0, 1].")
       .createWithDefault(0.7)
 
+  /**
+   * Enables the star-schema join reordering performed by `ReorderJoin`.
+   * Disabled by default.
+   */
+  val STARSCHEMA_DETECTION = buildConf("spark.sql.cbo.starSchemaDetection")
+    .doc("When true, it enables join reordering based on star schema detection. ")
+    .booleanConf
+    .createWithDefault(false)
+
+  /**
+   * Internal threshold used by star-join detection: when the second-largest table has
+   * more than `ratio * rows(largest table)` rows, the two largest tables are considered
+   * comparable in size and no star join is returned. Defaults to 0.9.
+   */
+  val STARSCHEMA_FACT_TABLE_RATIO = buildConf("spark.sql.cbo.starJoinFTRatio")
+    .internal()
+    .doc("Specifies the upper limit of the ratio between the largest fact tables" +
+      " for a star join to be considered. ")
+    .doubleConf
+    .createWithDefault(0.9)
+
   val SESSION_LOCAL_TIMEZONE =
     buildConf("spark.sql.session.timeZone")
       .doc("""The ID of session local timezone, e.g. "GMT", 
"America/Los_Angeles", etc.""")
@@ -988,6 +1000,10 @@ class SQLConf extends Serializable with Logging {
 
   def maxNestedViewDepth: Int = getConf(SQLConf.MAX_NESTED_VIEW_DEPTH)
 
+  /** Value of `spark.sql.cbo.starSchemaDetection`. */
+  def starSchemaDetection: Boolean = getConf(STARSCHEMA_DETECTION)
+
+  /** Value of the internal `spark.sql.cbo.starJoinFTRatio` setting. */
+  def starSchemaFTRatio: Double = getConf(STARSCHEMA_FACT_TABLE_RATIO)
+
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */

http://git-wip-us.apache.org/repos/asf/spark/blob/81639115/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
index 985e490..61e8180 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
@@ -26,7 +26,7 @@ import 
org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
 import org.apache.spark.sql.catalyst.plans.{Cross, Inner, InnerLike, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
-
+import org.apache.spark.sql.catalyst.SimpleCatalystConf
 
 class JoinOptimizationSuite extends PlanTest {
 
@@ -38,7 +38,7 @@ class JoinOptimizationSuite extends PlanTest {
         CombineFilters,
         PushDownPredicate,
         BooleanSimplification,
-        ReorderJoin,
+        ReorderJoin(SimpleCatalystConf(true)),
         PushPredicateThroughJoin,
         ColumnPruning,
         CollapseProject) :: Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/81639115/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
index 5607bcd..05b839b 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
@@ -22,10 +22,9 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
 import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, 
LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, 
StatsTestPlan}
-import org.apache.spark.sql.catalyst.util._
 
 
 class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
@@ -38,7 +37,7 @@ class JoinReorderSuite extends PlanTest with 
StatsEstimationTestBase {
       Batch("Operator Optimizations", FixedPoint(100),
         CombineFilters,
         PushDownPredicate,
-        ReorderJoin,
+        ReorderJoin(conf),
         PushPredicateThroughJoin,
         ColumnPruning,
         CollapseProject) ::
@@ -203,27 +202,7 @@ class JoinReorderSuite extends PlanTest with 
StatsEstimationTestBase {
       originalPlan: LogicalPlan,
       groundTruthBestPlan: LogicalPlan): Unit = {
     val optimized = Optimize.execute(originalPlan.analyze)
-    val normalized1 = normalizePlan(normalizeExprIds(optimized))
-    val normalized2 = 
normalizePlan(normalizeExprIds(groundTruthBestPlan.analyze))
-    if (!sameJoinPlan(normalized1, normalized2)) {
-      fail(
-        s"""
-           |== FAIL: Plans do not match ===
-           |${sideBySide(normalized1.treeString, 
normalized2.treeString).mkString("\n")}
-         """.stripMargin)
-    }
-  }
-
-  /** Consider symmetry for joins when comparing plans. */
-  private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = {
-    (plan1, plan2) match {
-      case (j1: Join, j2: Join) =>
-        (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) ||
-          (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left))
-      case _ if plan1.children.nonEmpty && plan2.children.nonEmpty =>
-        (plan1.children, plan2.children).zipped.forall { case (c1, c2) => 
sameJoinPlan(c1, c2) }
-      case _ =>
-        plan1 == plan2
-    }
+    val expected = groundTruthBestPlan.analyze
+    compareJoinOrder(optimized, expected)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/81639115/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala
new file mode 100644
index 0000000..93fdd98
--- /dev/null
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala
@@ -0,0 +1,580 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.SimpleCatalystConf
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
+import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
+import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LocalRelation, 
LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, 
StatsTestPlan}
+
+
+class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
+
+  // Case-sensitive analysis, with star-schema detection switched on so ReorderJoin
+  // uses the star-join heuristics under test.
+  override val conf = SimpleCatalystConf(
+    starSchemaDetection = true,
+    caseSensitiveAnalysis = true)
+
+  /** Single batch of operator optimizations, run to a fixed point (at most 100 iterations). */
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    private val operatorOptimizations = Batch(
+      "Operator Optimizations",
+      FixedPoint(100),
+      CombineFilters,
+      PushDownPredicate,
+      ReorderJoin(conf),
+      PushPredicateThroughJoin,
+      ColumnPruning,
+      CollapseProject)
+
+    val batches = List(operatorOptimizations)
+  }
+
+  // Table setup using star schema relationships:
+  //
+  // d1 - f1 - d2
+  //      |
+  //      d3 - s3
+  //
+  // Table f1 is the fact table. Tables d1, d2, and d3 are the dimension tables.
+  // Dimension d3 is further joined/normalized into table s3.
+  // Tables' cardinality: f1 > d3 > d1 > d2 > s3
+  // Table f11 carries the same column statistics as f1 and serves as a second fact table.
+
+  /** Integer column statistic with 4-byte values, `ndv` distinct values in [lo, hi]. */
+  private def intColStat(ndv: Int, lo: Int, hi: Int, nulls: Int = 0): ColumnStat =
+    ColumnStat(distinctCount = ndv, min = Some(lo), max = Some(hi),
+      nullCount = nulls, avgLen = 4, maxLen = 4)
+
+  private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq(
+    // F1
+    attr("f1_fk1") -> intColStat(ndv = 3, lo = 1, hi = 3),
+    attr("f1_fk2") -> intColStat(ndv = 3, lo = 1, hi = 3),
+    attr("f1_fk3") -> intColStat(ndv = 4, lo = 1, hi = 4),
+    attr("f1_c4") -> intColStat(ndv = 4, lo = 1, hi = 4),
+    // D1
+    attr("d1_pk1") -> intColStat(ndv = 4, lo = 1, hi = 4),
+    attr("d1_c2") -> intColStat(ndv = 3, lo = 1, hi = 3),
+    attr("d1_c3") -> intColStat(ndv = 4, lo = 1, hi = 4),
+    attr("d1_c4") -> intColStat(ndv = 2, lo = 2, hi = 3),
+    // D2 (d2_c2 is the only column declared with a null count, which makes it nullable)
+    attr("d2_c2") -> intColStat(ndv = 3, lo = 1, hi = 3, nulls = 1),
+    attr("d2_pk1") -> intColStat(ndv = 3, lo = 1, hi = 3),
+    attr("d2_c3") -> intColStat(ndv = 3, lo = 1, hi = 3),
+    attr("d2_c4") -> intColStat(ndv = 2, lo = 3, hi = 4),
+    // D3
+    attr("d3_fk1") -> intColStat(ndv = 3, lo = 1, hi = 3),
+    attr("d3_c2") -> intColStat(ndv = 3, lo = 1, hi = 3),
+    attr("d3_pk1") -> intColStat(ndv = 5, lo = 1, hi = 5),
+    attr("d3_c4") -> intColStat(ndv = 2, lo = 2, hi = 3),
+    // S3
+    attr("s3_pk1") -> intColStat(ndv = 2, lo = 1, hi = 2),
+    attr("s3_c2") -> intColStat(ndv = 1, lo = 3, hi = 3),
+    attr("s3_c3") -> intColStat(ndv = 1, lo = 3, hi = 3),
+    attr("s3_c4") -> intColStat(ndv = 2, lo = 3, hi = 4),
+    // F11
+    attr("f11_fk1") -> intColStat(ndv = 3, lo = 1, hi = 3),
+    attr("f11_fk2") -> intColStat(ndv = 3, lo = 1, hi = 3),
+    attr("f11_fk3") -> intColStat(ndv = 4, lo = 1, hi = 4),
+    attr("f11_c4") -> intColStat(ndv = 4, lo = 1, hi = 4)
+  ))
+
+  // Lookups from column name to its attribute, and to its (attribute, statistic) entry.
+  private val nameToAttr: Map[String, Attribute] =
+    columnInfo.map { case (attribute, _) => attribute.name -> attribute }
+  private val nameToColInfo: Map[String, (Attribute, ColumnStat)] =
+    columnInfo.map { case entry @ (attribute, _) => attribute.name -> entry }
+
+  /** Attributes of the named columns, in the order given. */
+  private def attrsOf(columns: String*): Seq[Attribute] = columns.toList.map(nameToAttr)
+
+  /** Column statistics of the named columns. */
+  private def statsOf(columns: String*): AttributeMap[ColumnStat] =
+    AttributeMap(columns.toList.map(nameToColInfo))
+
+  private val f1 = StatsTestPlan(
+    outputList = attrsOf("f1_fk1", "f1_fk2", "f1_fk3", "f1_c4"),
+    rowCount = 6,
+    size = Some(48),
+    attributeStats = statsOf("f1_fk1", "f1_fk2", "f1_fk3", "f1_c4"))
+
+  private val d1 = StatsTestPlan(
+    outputList = attrsOf("d1_pk1", "d1_c2", "d1_c3", "d1_c4"),
+    rowCount = 4,
+    size = Some(32),
+    attributeStats = statsOf("d1_pk1", "d1_c2", "d1_c3", "d1_c4"))
+
+  private val d2 = StatsTestPlan(
+    outputList = attrsOf("d2_c2", "d2_pk1", "d2_c3", "d2_c4"),
+    rowCount = 3,
+    size = Some(24),
+    attributeStats = statsOf("d2_c2", "d2_pk1", "d2_c3", "d2_c4"))
+
+  private val d3 = StatsTestPlan(
+    outputList = attrsOf("d3_fk1", "d3_c2", "d3_pk1", "d3_c4"),
+    rowCount = 5,
+    size = Some(40),
+    attributeStats = statsOf("d3_fk1", "d3_c2", "d3_pk1", "d3_c4"))
+
+  private val s3 = StatsTestPlan(
+    outputList = attrsOf("s3_pk1", "s3_c2", "s3_c3", "s3_c4"),
+    rowCount = 2,
+    size = Some(17),
+    attributeStats = statsOf("s3_pk1", "s3_c2", "s3_c3", "s3_c4"))
+
+  // Same column names as d3, built as a LocalRelation instead of a StatsTestPlan, so no row
+  // count or column statistics are supplied for it.
+  // NOTE(review): these attributes are created here, independently of nameToAttr, while the
+  // predicates in Test 5 use nameToAttr("d3_*"); confirm they are meant to match d3_ns.
+  private val d3_ns = LocalRelation('d3_fk1.int, 'd3_c2.int, 'd3_pk1.int, 'd3_c4.int)
+
+  private val f11 = StatsTestPlan(
+    outputList = attrsOf("f11_fk1", "f11_fk2", "f11_fk3", "f11_c4"),
+    rowCount = 6,
+    size = Some(48),
+    attributeStats = statsOf("f11_fk1", "f11_fk2", "f11_fk3", "f11_c4"))
+
+  // Aggregating sub-query over d3, used as a non-table join input in Test 6.
+  private val subq = d3.select(sum('d3_fk1).as('col))
+
+  test("Test 1: Selective star-join on all dimensions") {
+    // Star join:
+    //   (=)  (=)
+    // d1 - f1 - d2
+    //      | (=)
+    //      d3 - s3
+    //
+    // Query:
+    //  select f1_fk1, f1_fk3
+    //  from d1, d2, f1, d3, s3
+    //  where f1_fk2 = d2_pk1 and d2_c2 = 2
+    //  and f1_fk1 = d1_pk1
+    //  and f1_fk3 = d3_pk1
+    //  and d3_fk1 = s3_pk1
+    //
+    // Positional join reordering: d1, f1, d2, d3, s3
+    // Star join reordering: f1, d2, d1, d3, s3
+    val f1ToD2 = nameToAttr("f1_fk2") === nameToAttr("d2_pk1")
+    val d2Filter = nameToAttr("d2_c2") === 2
+    val f1ToD1 = nameToAttr("f1_fk1") === nameToAttr("d1_pk1")
+    val f1ToD3 = nameToAttr("f1_fk3") === nameToAttr("d3_pk1")
+    val d3ToS3 = nameToAttr("d3_fk1") === nameToAttr("s3_pk1")
+
+    val query = d1.join(d2).join(f1).join(d3).join(s3)
+      .where(f1ToD2 && d2Filter && f1ToD1 && f1ToD3 && d3ToS3)
+
+    val expected = f1
+      .join(d2.where(d2Filter), Inner, Some(f1ToD2))
+      .join(d1, Inner, Some(f1ToD1))
+      .join(d3, Inner, Some(f1ToD3))
+      .join(s3, Inner, Some(d3ToS3))
+
+    assertEqualPlans(query, expected)
+  }
+
+  test("Test 2: Star join on a subset of dimensions due to inequality joins") {
+    // Star join:
+    //   (=)  (<)
+    // d1 - f1 - d2
+    //      |
+    //      | (=)
+    //      d3 - s3
+    //        (=)
+    //
+    // Query:
+    //  select f1_fk1, f1_fk3
+    //  from d1, f1, d2, s3, d3
+    //  where f1_fk2 < d2_pk1
+    //  and f1_fk1 = d1_pk1 and d1_c2 = 2
+    //  and f1_fk3 = d3_pk1
+    //  and d3_fk1 = s3_pk1
+    //
+    // d2 is joined through an inequality, so it is left out of the star join.
+    //
+    // Default join reordering: d1, f1, d2, d3, s3
+    // Star join reordering: f1, d1, d3, d2, s3
+    val f1ToD2 = nameToAttr("f1_fk2") < nameToAttr("d2_pk1")
+    val f1ToD1 = nameToAttr("f1_fk1") === nameToAttr("d1_pk1")
+    val d1Filter = nameToAttr("d1_c2") === 2
+    val f1ToD3 = nameToAttr("f1_fk3") === nameToAttr("d3_pk1")
+    val d3ToS3 = nameToAttr("d3_fk1") === nameToAttr("s3_pk1")
+
+    val query = d1.join(f1).join(d2).join(s3).join(d3)
+      .where(f1ToD2 && f1ToD1 && d1Filter && f1ToD3 && d3ToS3)
+
+    val expected = f1
+      .join(d1.where(d1Filter), Inner, Some(f1ToD1))
+      .join(d3, Inner, Some(f1ToD3))
+      .join(d2, Inner, Some(f1ToD2))
+      .join(s3, Inner, Some(d3ToS3))
+
+    assertEqualPlans(query, expected)
+  }
+
+  test("Test 3:  Star join on a subset of dimensions since join column is not unique") {
+    // Star join:
+    //   (=)  (=)
+    // d1 - f1 - d2
+    //      | (=)
+    //      d3 - s3
+    //
+    // Query:
+    //  select f1_fk1, f1_fk3
+    //  from d1, f1, d2, s3, d3
+    //  where f1_fk2 = d2_c4
+    //  and f1_fk1 = d1_pk1 and d1_c2 = 2
+    //  and f1_fk3 = d3_pk1
+    //  and d3_fk1 = s3_pk1
+    //
+    // d2 is joined on d2_c4, which is not unique (2 distinct values over 3 rows), so d2 is
+    // left out of the star join.
+    //
+    // Default join reordering: d1, f1, d2, d3, s3
+    // Star join reordering: f1, d1, d3, d2, s3
+    val query =
+      d1.join(f1).join(d2).join(s3).join(d3)
+        .where((nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) &&
+          (nameToAttr("d1_c2") === 2) &&
+          (nameToAttr("f1_fk2") === nameToAttr("d2_c4")) &&
+          (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) &&
+          (nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
+
+    // Each join carries the query predicate that actually connects its inputs. The join
+    // conditions are not checked by compareJoinOrder, so mistakes here would go unnoticed.
+    val expected =
+      f1.join(d1.where(nameToAttr("d1_c2") === 2), Inner,
+          Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1")))
+        .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1")))
+        .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_c4")))
+        .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
+
+    assertEqualPlans(query, expected)
+  }
+
+  test("Test 4: Star join on a subset of dimensions since join column is nullable") {
+    // Star join:
+    //   (=)  (=)
+    // d1 - f1 - d2
+    //      | (=)
+    //      d3 - s3
+    //
+    // Query:
+    //  select f1_fk1, f1_fk3
+    //  from d1, f1, d2, s3, d3
+    //  where f1_fk2 = d2_c2
+    //  and f1_fk1 = d1_pk1 and d1_c2 = 2
+    //  and f1_fk3 = d3_pk1
+    //  and d3_fk1 = s3_pk1
+    //
+    // d2 is joined on d2_c2, whose statistics report a null, so d2 is left out of the
+    // star join.
+    //
+    // Default join reordering: d1, f1, d2, d3, s3
+    // Star join reordering: f1, d1, d3, d2, s3
+
+    val query =
+      d1.join(f1).join(d2).join(s3).join(d3)
+        .where((nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) &&
+          (nameToAttr("d1_c2") === 2) &&
+          (nameToAttr("f1_fk2") === nameToAttr("d2_c2")) &&
+          (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) &&
+          (nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
+
+    // The s3 join uses the same equality predicate as the query. The join conditions are
+    // not checked by compareJoinOrder, so a mismatch here would go unnoticed.
+    val expected =
+      f1.join(d1.where(nameToAttr("d1_c2") === 2), Inner,
+          Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1")))
+        .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1")))
+        .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_c2")))
+        .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
+
+    assertEqualPlans(query, expected)
+  }
+
+  test("Test 5: Table stats not available for some of the joined tables") {
+    // Star join:
+    //   (=)  (=)
+    // d1 - f1 - d2
+    //      | (=)
+    //      d3_ns - s3
+    //
+    //  select f1_fk1, f1_fk3
+    //  from d3_ns, f1, d1, d2, s3
+    //  where f1_fk2 = d2_pk1 and d2_c2 = 2
+    //  and f1_fk1 = d1_pk1
+    //  and f1_fk3 = d3_pk1
+    //  and d3_fk1 = s3_pk1
+    //
+    // Positional join reordering: d3_ns, f1, d1, d2, s3
+    // Star join reordering: empty (no table statistics are provided for d3_ns)
+    // NOTE(review): the d3_* attributes below come from nameToAttr, not from d3_ns's own
+    // output; confirm these predicates are meant to reference d3_ns.
+    val f1ToD2 = nameToAttr("f1_fk2") === nameToAttr("d2_pk1")
+    val d2Filter = nameToAttr("d2_c2") === 2
+    val f1ToD1 = nameToAttr("f1_fk1") === nameToAttr("d1_pk1")
+    val f1ToD3 = nameToAttr("f1_fk3") === nameToAttr("d3_pk1")
+    val d3ToS3 = nameToAttr("d3_fk1") === nameToAttr("s3_pk1")
+
+    val query = d3_ns.join(f1).join(d1).join(d2).join(s3)
+      .where(f1ToD2 && d2Filter && f1ToD1 && f1ToD3 && d3ToS3)
+
+    val expected = d3_ns
+      .join(f1, Inner, Some(f1ToD3))
+      .join(d1, Inner, Some(f1ToD1))
+      .join(d2.where(d2Filter), Inner, Some(f1ToD2))
+      .join(s3, Inner, Some(d3ToS3))
+
+    assertEqualPlans(query, expected)
+  }
+
+  test("Test 6: Join with complex plans") {
+    // Star join:
+    //   (=)  (=)
+    // d1 - f1 - d2
+    //      | (=)
+    //      (sub-query)
+    //
+    //  select f1_fk1, f1_fk3
+    //  from (select sum(d3_fk1) as col from d3) subq, f1, d1, d2
+    //  where f1_fk2 = d2_pk1 and d2_c2 = 2
+    //  and f1_fk1 = d1_pk1
+    //  and f1_fk3 = subq.col
+    //
+    // Positional join reordering: subq, f1, d1, d2
+    // Star join reordering: empty
+    val f1ToD2 = nameToAttr("f1_fk2") === nameToAttr("d2_pk1")
+    val d2Filter = nameToAttr("d2_c2") === 2
+    val f1ToD1 = nameToAttr("f1_fk1") === nameToAttr("d1_pk1")
+    val f1ToSubq = nameToAttr("f1_fk3") === "col".attr
+
+    val query = subq.join(f1).join(d1).join(d2)
+      .where(f1ToD2 && d2Filter && f1ToD1 && f1ToSubq)
+
+    // The sub-query side includes the d3_fk1 projection (presumably added by ColumnPruning).
+    val expected = d3.select('d3_fk1).select(sum('d3_fk1).as('col))
+      .join(f1, Inner, Some(f1ToSubq))
+      .join(d1, Inner, Some(f1ToD1))
+      .join(d2.where(d2Filter), Inner, Some(f1ToD2))
+
+    assertEqualPlans(query, expected)
+  }
+
+  test("Test 7: Comparable fact table sizes") {
+    // Star join:
+    //   (=)  (=)
+    // d1 - f1 - d2
+    //      | (=)
+    //      f11 - s3
+    //
+    // select f1.f1_fk1, f1.f1_fk3
+    // from d1, f11, f1, d2, s3
+    // where f1.f1_fk2 = d2_pk1 and d2_c2 = 2
+    // and f1.f1_fk1 = d1_pk1
+    // and f1.f1_fk3 = f11.f11_fk3
+    // and f11.f11_fk1 = s3_pk1
+    //
+    // Positional join reordering: d1, f1, f11, d2, s3
+    // Star join reordering: empty
+    val f1ToD2 = nameToAttr("f1_fk2") === nameToAttr("d2_pk1")
+    val d2Filter = nameToAttr("d2_c2") === 2
+    val f1ToD1 = nameToAttr("f1_fk1") === nameToAttr("d1_pk1")
+    val f1ToF11 = nameToAttr("f1_fk3") === nameToAttr("f11_fk3")
+    val f11ToS3 = nameToAttr("f11_fk1") === nameToAttr("s3_pk1")
+
+    val query = d1.join(f11).join(f1).join(d2).join(s3)
+      .where(f1ToD2 && d2Filter && f1ToD1 && f1ToF11 && f11ToS3)
+
+    val expected = d1
+      .join(f1, Inner, Some(f1ToD1))
+      .join(f11, Inner, Some(f1ToF11))
+      .join(d2.where(d2Filter), Inner, Some(f1ToD2))
+      .join(s3, Inner, Some(f11ToS3))
+
+    assertEqualPlans(query, expected)
+  }
+
+  test("Test 8: No RI joins") {
+    // Star join:
+    //   (=)  (=)
+    // d1 - f1 - d2
+    //      | (=)
+    //      d3 - s3
+    //
+    //  select f1_fk1, f1_fk3
+    //  from d1, d3, f1, d2, s3
+    //  where f1_fk2 = d2_c4 and d2_c2 = 2
+    //  and f1_fk1 = d1_c4
+    //  and f1_fk3 = d3_c4
+    //  and d3_fk1 = s3_pk1
+    //
+    // The fact table joins each dimension on a non-key (*_c4) column.
+    //
+    // Positional/default join reordering: d1, f1, d3, d2, s3
+    // Star join reordering: empty
+    val f1ToD2 = nameToAttr("f1_fk2") === nameToAttr("d2_c4")
+    val d2Filter = nameToAttr("d2_c2") === 2
+    val f1ToD1 = nameToAttr("f1_fk1") === nameToAttr("d1_c4")
+    val f1ToD3 = nameToAttr("f1_fk3") === nameToAttr("d3_c4")
+    val d3ToS3 = nameToAttr("d3_fk1") === nameToAttr("s3_pk1")
+
+    val query = d1.join(d3).join(f1).join(d2).join(s3)
+      .where(f1ToD2 && d2Filter && f1ToD1 && f1ToD3 && d3ToS3)
+
+    val expected = d1
+      .join(f1, Inner, Some(f1ToD1))
+      .join(d3, Inner, Some(f1ToD3))
+      .join(d2.where(d2Filter), Inner, Some(f1ToD2))
+      .join(s3, Inner, Some(d3ToS3))
+
+    assertEqualPlans(query, expected)
+  }
+
+  test("Test 9: Complex join predicates") {
+    // Star join:
+    //   (=)  (=)
+    // d1 - f1 - d2
+    //      | (=)
+    //      d3 - s3
+    //
+    // select f1_fk1, f1_fk3
+    // from d1, d3, f1, d2, s3
+    // where f1_fk2 = d2_pk1 and d2_c2 = 2
+    // and abs(f1_fk1) = d1_pk1
+    // and f1_fk3 = d3_pk1
+    // and d3_fk1 = s3_pk1
+    //
+    // Positional/default join reordering: d1, f1, d3, d2, s3
+    // Star join reordering: empty
+    val f1ToD2 = nameToAttr("f1_fk2") === nameToAttr("d2_pk1")
+    val d2Filter = nameToAttr("d2_c2") === 2
+    val absF1ToD1 = abs(nameToAttr("f1_fk1")) === nameToAttr("d1_pk1")
+    val f1ToD3 = nameToAttr("f1_fk3") === nameToAttr("d3_pk1")
+    val d3ToS3 = nameToAttr("d3_fk1") === nameToAttr("s3_pk1")
+
+    val query = d1.join(d3).join(f1).join(d2).join(s3)
+      .where(f1ToD2 && d2Filter && absF1ToD1 && f1ToD3 && d3ToS3)
+
+    val expected = d1
+      .join(f1, Inner, Some(absF1ToD1))
+      .join(d3, Inner, Some(f1ToD3))
+      .join(d2.where(d2Filter), Inner, Some(f1ToD2))
+      .join(s3, Inner, Some(d3ToS3))
+
+    assertEqualPlans(query, expected)
+  }
+
+  test("Test 10: Less than two dimensions") {
+    // Star join:
+    //   (<)  (=)
+    // d1 - f1 - d2
+    //      |(<)
+    //      d3 - s3
+    //
+    // select f1_fk1, f1_fk3
+    // from d1, d3, f1, d2, s3
+    // where f1_fk2 = d2_pk1 and d2_c2 = 2
+    // and f1_fk1 < d1_pk1
+    // and f1_fk3 < d3_pk1
+    // and d3_fk1 = s3_pk1
+    //
+    // Only d2 is joined to f1 through an equality.
+    //
+    // Positional join reordering: d1, f1, d3, d2, s3
+    // Star join reordering: empty
+    val f1ToD2 = nameToAttr("f1_fk2") === nameToAttr("d2_pk1")
+    val d2Filter = nameToAttr("d2_c2") === 2
+    val f1ToD1 = nameToAttr("f1_fk1") < nameToAttr("d1_pk1")
+    val f1ToD3 = nameToAttr("f1_fk3") < nameToAttr("d3_pk1")
+    val d3ToS3 = nameToAttr("d3_fk1") === nameToAttr("s3_pk1")
+
+    val query = d1.join(d3).join(f1).join(d2).join(s3)
+      .where(f1ToD2 && d2Filter && f1ToD1 && f1ToD3 && d3ToS3)
+
+    val expected = d1
+      .join(f1, Inner, Some(f1ToD1))
+      .join(d3, Inner, Some(f1ToD3))
+      .join(d2.where(d2Filter), Inner, Some(f1ToD2))
+      .join(s3, Inner, Some(d3ToS3))
+
+    assertEqualPlans(query, expected)
+  }
+
+  test("Test 11: Expanding star join") {
+    // Star join:
+    //   (<)  (<)
+    // d1 - f1 - d2
+    //      | (<)
+    //      d3 - s3
+    //
+    // select f1_fk1, f1_fk3
+    // from d1, d3, f1, d2, s3
+    // where f1_fk2 < d2_pk1
+    // and f1_fk1 < d1_pk1
+    // and f1_fk3 < d3_pk1
+    // and d3_fk1 < s3_pk1
+    //
+    // Positional join reordering: d1, f1, d3, d2, s3
+    // Star join reordering: empty
+    val f1ToD2 = nameToAttr("f1_fk2") < nameToAttr("d2_pk1")
+    val f1ToD1 = nameToAttr("f1_fk1") < nameToAttr("d1_pk1")
+    val f1ToD3 = nameToAttr("f1_fk3") < nameToAttr("d3_pk1")
+    val d3ToS3 = nameToAttr("d3_fk1") < nameToAttr("s3_pk1")
+
+    val query = d1.join(d3).join(f1).join(d2).join(s3)
+      .where(f1ToD2 && f1ToD1 && f1ToD3 && d3ToS3)
+
+    val expected = d1
+      .join(f1, Inner, Some(f1ToD1))
+      .join(d3, Inner, Some(f1ToD3))
+      .join(d2, Inner, Some(f1ToD2))
+      .join(s3, Inner, Some(d3ToS3))
+
+    assertEqualPlans(query, expected)
+  }
+
+  test("Test 12: Non selective star join") {
+    // Star join:
+    //   (=)  (=)
+    // d1 - f1 - d2
+    //      | (=)
+    //      d3 - s3
+    //
+    //  select f1_fk1, f1_fk3
+    //  from d1, d3, f1, d2, s3
+    //  where f1_fk2 = d2_pk1
+    //  and f1_fk1 = d1_pk1
+    //  and f1_fk3 = d3_pk1
+    //  and d3_fk1 = s3_pk1
+    //
+    // No dimension carries a local filter.
+    //
+    // Positional join reordering: d1, f1, d3, d2, s3
+    // Star join reordering: empty
+    val f1ToD2 = nameToAttr("f1_fk2") === nameToAttr("d2_pk1")
+    val f1ToD1 = nameToAttr("f1_fk1") === nameToAttr("d1_pk1")
+    val f1ToD3 = nameToAttr("f1_fk3") === nameToAttr("d3_pk1")
+    val d3ToS3 = nameToAttr("d3_fk1") === nameToAttr("s3_pk1")
+
+    val query = d1.join(d3).join(f1).join(d2).join(s3)
+      .where(f1ToD2 && f1ToD1 && f1ToD3 && d3ToS3)
+
+    val expected = d1
+      .join(f1, Inner, Some(f1ToD1))
+      .join(d3, Inner, Some(f1ToD3))
+      .join(d2, Inner, Some(f1ToD2))
+      .join(s3, Inner, Some(d3ToS3))
+
+    assertEqualPlans(query, expected)
+  }
+
+  /** Runs `Optimize` on the analyzed `plan1` and checks its join order against the analyzed `plan2`. */
+  private def assertEqualPlans(plan1: LogicalPlan, plan2: LogicalPlan): Unit =
+    compareJoinOrder(Optimize.execute(plan1.analyze), plan2.analyze)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/81639115/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 5eb3141..2a9d057 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -106,4 +106,30 @@ abstract class PlanTest extends SparkFunSuite with 
PredicateHelper {
   /** Compares two expressions by wrapping each as a Filter over OneRowRelation and using comparePlans. */
   protected def compareExpressions(e1: Expression, e2: Expression): Unit = {
     def asFilter(condition: Expression): LogicalPlan = Filter(condition, OneRowRelation)
     comparePlans(asFilter(e1), asFilter(e2))
   }
+
+  /**
+   * Fails the test if the join order in the two plans does not match.
+   *
+   * Both plans are first passed through normalizeExprIds and normalizePlan. They are then
+   * compared with sameJoinPlan, which accepts either side of a join being swapped and does
+   * not compare join conditions.
+   */
+  protected def compareJoinOrder(plan1: LogicalPlan, plan2: LogicalPlan): Unit = {
+    val normalized1 = normalizePlan(normalizeExprIds(plan1))
+    val normalized2 = normalizePlan(normalizeExprIds(plan2))
+    if (!sameJoinPlan(normalized1, normalized2)) {
+      fail(
+        s"""
+           |== FAIL: Plans do not match ===
+           |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")}
+         """.stripMargin)
+    }
+  }
+
+  /**
+   * Structural plan comparison that treats joins as symmetric.
+   *
+   * Two joins match if their children match in the same order or swapped; join conditions
+   * and join types are not compared. Other nodes with children match only when they are of
+   * the same class, have the same number of children, and all children match pairwise.
+   * Leaf nodes must be equal.
+   */
+  private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = {
+    (plan1, plan2) match {
+      case (j1: Join, j2: Join) =>
+        (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) ||
+          (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left))
+      case _ if plan1.children.nonEmpty && plan2.children.nonEmpty =>
+        // Check node kind and arity: zipping the children alone would drop any extra
+        // children and would let, e.g., a Join compare equal to a Filter.
+        plan1.getClass == plan2.getClass &&
+          plan1.children.length == plan2.children.length &&
+          plan1.children.zip(plan2.children).forall { case (c1, c2) => sameJoinPlan(c1, c2) }
+      case _ =>
+        plan1 == plan2
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to