Repository: spark
Updated Branches:
  refs/heads/master 21ad84623 -> 2cef1bb0b


[SPARK-5354][SQL] Cached tables should preserve partitioning and ordering.

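For cached tables, we can just maintain the partitioning and ordering of the
source relation, so queries over a cached table can avoid unnecessary exchanges.
In addition, when every child's existing partitioning already satisfies its
required distribution with the same number of partitions, EnsureRequirements now
keeps that partitioning instead of re-shuffling to the default number of shuffle
partitions.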

Author: Nong Li <[email protected]>

Closes #9404 from nongli/spark-5354.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2cef1bb0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2cef1bb0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2cef1bb0

Branch: refs/heads/master
Commit: 2cef1bb0b560a03aa7308f694b0c66347b90c9ea
Parents: 21ad846
Author: Nong Li <[email protected]>
Authored: Mon Nov 2 19:18:45 2015 -0800
Committer: Yin Huai <[email protected]>
Committed: Mon Nov 2 19:18:45 2015 -0800

----------------------------------------------------------------------
 .../columnar/InMemoryColumnarTableScan.scala    |  7 +++
 .../apache/spark/sql/execution/Exchange.scala   | 40 ++++++++++---
 .../org/apache/spark/sql/CachedTableSuite.scala | 59 ++++++++++++++++++++
 3 files changed, 97 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2cef1bb0/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index b4607b1..7eb1ad7 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -25,6 +25,7 @@ import 
org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.{ConvertToUnsafe, LeafNode, SparkPlan}
 import org.apache.spark.sql.types.UserDefinedType
 import org.apache.spark.storage.StorageLevel
@@ -209,6 +210,12 @@ private[sql] case class InMemoryColumnarTableScan(
 
   override def output: Seq[Attribute] = attributes
 
+  // The cached version does not change the outputPartitioning of the original 
SparkPlan.
+  override def outputPartitioning: Partitioning = 
relation.child.outputPartitioning
+
+  // The cached version does not change the outputOrdering of the original 
SparkPlan.
+  override def outputOrdering: Seq[SortOrder] = relation.child.outputOrdering
+
   override def outputsUnsafeRows: Boolean = true
 
   private def statsFor(a: Attribute) = 
relation.partitionStatistics.forAttribute(a)

http://git-wip-us.apache.org/repos/asf/spark/blob/2cef1bb0/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 7f60c8f..e81108b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -194,12 +194,13 @@ case class Exchange(newPartitioning: Partitioning, child: 
SparkPlan) extends Una
  */
 private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends 
Rule[SparkPlan] {
   // TODO: Determine the number of partitions.
-  private def numPartitions: Int = sqlContext.conf.numShufflePartitions
+  private def defaultPartitions: Int = sqlContext.conf.numShufflePartitions
 
   /**
    * Given a required distribution, returns a partitioning that satisfies that 
distribution.
    */
-  private def canonicalPartitioning(requiredDistribution: Distribution): 
Partitioning = {
+  private def createPartitioning(requiredDistribution: Distribution,
+      numPartitions: Int): Partitioning = {
     requiredDistribution match {
       case AllTuples => SinglePartition
       case ClusteredDistribution(clustering) => HashPartitioning(clustering, 
numPartitions)
@@ -220,7 +221,7 @@ private[sql] case class EnsureRequirements(sqlContext: 
SQLContext) extends Rule[
       if (child.outputPartitioning.satisfies(distribution)) {
         child
       } else {
-        Exchange(canonicalPartitioning(distribution), child)
+        Exchange(createPartitioning(distribution, defaultPartitions), child)
       }
     }
 
@@ -229,12 +230,33 @@ private[sql] case class EnsureRequirements(sqlContext: 
SQLContext) extends Rule[
     if (children.length > 1
         && requiredChildDistributions.toSet != Set(UnspecifiedDistribution)
         && !Partitioning.allCompatible(children.map(_.outputPartitioning))) {
-      children = children.zip(requiredChildDistributions).map { case (child, 
distribution) =>
-        val targetPartitioning = canonicalPartitioning(distribution)
-        if (child.outputPartitioning.guarantees(targetPartitioning)) {
-          child
-        } else {
-          Exchange(targetPartitioning, child)
+
+      // First check if the existing partitions of the children all match. This means they are
+      // partitioned by the same partitioning into the same number of partitions. In that case,
+      // don't try to make them match `defaultPartitions`, just use the existing partitioning.
+      // TODO: this should be a cost-based decision. For example, a big relation should probably
+      // maintain its existing number of partitions and smaller relations should be shuffled.
+      // defaultPartitions is arbitrary.
+      val numPartitions = children.head.outputPartitioning.numPartitions
+      val useExistingPartitioning = 
children.zip(requiredChildDistributions).forall {
+        case (child, distribution) => {
+          child.outputPartitioning.guarantees(
+            createPartitioning(distribution, numPartitions))
+        }
+      }
+
+      children = if (useExistingPartitioning) {
+        children
+      } else {
+        children.zip(requiredChildDistributions).map {
+          case (child, distribution) => {
+            val targetPartitioning = createPartitioning(distribution, 
defaultPartitions)
+            if (child.outputPartitioning.guarantees(targetPartitioning)) {
+              child
+            } else {
+              Exchange(targetPartitioning, child)
+            }
+          }
         }
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/2cef1bb0/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index fd566c8..605954b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql
 
 import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
+import org.apache.spark.sql.execution.Exchange
 import org.apache.spark.sql.execution.PhysicalRDD
 
 import scala.concurrent.duration._
@@ -353,4 +354,62 @@ class CachedTableSuite extends QueryTest with 
SharedSQLContext {
     assert(sparkPlan.collect { case e: InMemoryColumnarTableScan => e }.size 
=== 3)
     assert(sparkPlan.collect { case e: PhysicalRDD => e }.size === 0)
   }
+
+  /**
+   * Verifies that the plan for `df` contains `expected` number of Exchange 
operators.
+   */
+  private def verifyNumExchanges(df: DataFrame, expected: Int): Unit = {
+    assert(df.queryExecution.executedPlan.collect { case e: Exchange => e 
}.size == expected)
+  }
+
+  test("A cached table preserves the partitioning and ordering of its cached 
SparkPlan") {
+    val table3x = testData.unionAll(testData).unionAll(testData)
+    table3x.registerTempTable("testData3x")
+
+    sql("SELECT key, value FROM testData3x ORDER BY 
key").registerTempTable("orderedTable")
+    sqlContext.cacheTable("orderedTable")
+    assertCached(sqlContext.table("orderedTable"))
+    // Should not have an exchange as the query is already range partitioned on the group by key.
+    verifyNumExchanges(sql("SELECT key, count(*) FROM orderedTable GROUP BY 
key"), 0)
+    checkAnswer(
+      sql("SELECT key, count(*) FROM orderedTable GROUP BY key ORDER BY key"),
+      sql("SELECT key, count(*) FROM testData3x GROUP BY key ORDER BY 
key").collect())
+    sqlContext.uncacheTable("orderedTable")
+
+    // Set up two tables distributed in the same way. Try this with the data distributed into
+    // different numbers of partitions.
+    for (numPartitions <- 1 until 10 by 4) {
+      testData.distributeBy(Column("key") :: Nil, 
numPartitions).registerTempTable("t1")
+      testData2.distributeBy(Column("a") :: Nil, 
numPartitions).registerTempTable("t2")
+      sqlContext.cacheTable("t1")
+      sqlContext.cacheTable("t2")
+
+      // Joining them should result in no exchanges.
+      verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = 
t2.a"), 0)
+      checkAnswer(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"),
+        sql("SELECT * FROM testData t1 JOIN testData2 t2 ON t1.key = t2.a"))
+
+      // Grouping on the partition key should result in no exchanges
+      verifyNumExchanges(sql("SELECT count(*) FROM t1 GROUP BY key"), 0)
+      checkAnswer(sql("SELECT count(*) FROM t1 GROUP BY key"),
+        sql("SELECT count(*) FROM testData GROUP BY key"))
+
+      sqlContext.uncacheTable("t1")
+      sqlContext.uncacheTable("t2")
+      sqlContext.dropTempTable("t1")
+      sqlContext.dropTempTable("t2")
+    }
+
+    // Distribute the tables into non-matching numbers of partitions. Need to shuffle.
+    testData.distributeBy(Column("key") :: Nil, 6).registerTempTable("t1")
+    testData2.distributeBy(Column("a") :: Nil, 3).registerTempTable("t2")
+    sqlContext.cacheTable("t1")
+    sqlContext.cacheTable("t2")
+
+    verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 
2)
+    sqlContext.uncacheTable("t1")
+    sqlContext.uncacheTable("t2")
+    sqlContext.dropTempTable("t1")
+    sqlContext.dropTempTable("t2")
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to