Repository: spark
Updated Branches:
  refs/heads/master 4d5c12aa1 -> 4107cce58


[SPARK-2042] Prevent unnecessary shuffle triggered by take()

This PR implements `take()` on a `SchemaRDD` by inserting a logical limit
followed by a `collect()`. It also adds a Catalyst optimizer rule that
collapses adjacent limits. Together, these changes prevent the unnecessary
shuffle that `take()` sometimes triggers.

Author: Sameer Agarwal <sam...@databricks.com>

Closes #1048 from sameeragarwal/master and squashes the following commits:

3eeb848 [Sameer Agarwal] Fixing Tests
1b76ff1 [Sameer Agarwal] Deprecating limit(limitExpr: Expression) in v1.1.0
b723ac4 [Sameer Agarwal] Added limit folding tests
a0ff7c4 [Sameer Agarwal] Adding catalyst rule to fold two consecutive limits
8d42d03 [Sameer Agarwal] Implement take() as limit() followed by collect()


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4107cce5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4107cce5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4107cce5

Branch: refs/heads/master
Commit: 4107cce58c41160a0dc20339621eacdf8a8b1191
Parents: 4d5c12a
Author: Sameer Agarwal <sam...@databricks.com>
Authored: Wed Jun 11 12:01:04 2014 -0700
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Wed Jun 11 12:01:04 2014 -0700

----------------------------------------------------------------------
 .../apache/spark/sql/catalyst/dsl/package.scala |  2 +
 .../sql/catalyst/optimizer/Optimizer.scala      | 13 ++++
 .../catalyst/plans/logical/basicOperators.scala |  4 +-
 .../optimizer/CombiningLimitsSuite.scala        | 71 ++++++++++++++++++++
 .../scala/org/apache/spark/sql/SchemaRDD.scala  | 12 +++-
 5 files changed, 97 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4107cce5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 3cf163f..d177339 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -175,6 +175,8 @@ package object dsl {
 
     def where(condition: Expression) = Filter(condition, logicalPlan)
 
+    def limit(limitExpr: Expression) = Limit(limitExpr, logicalPlan)
+
     def join(
         otherPlan: LogicalPlan,
         joinType: JoinType = Inner,

http://git-wip-us.apache.org/repos/asf/spark/blob/4107cce5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index e41fd2d..28d1aa2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -29,6 +29,8 @@ import org.apache.spark.sql.catalyst.types._
 
 object Optimizer extends RuleExecutor[LogicalPlan] {
   val batches =
+    Batch("Combine Limits", FixedPoint(100),
+      CombineLimits) ::
     Batch("ConstantFolding", FixedPoint(100),
       NullPropagation,
       ConstantFolding,
@@ -362,3 +364,14 @@ object SimplifyCasts extends Rule[LogicalPlan] {
     case Cast(e, dataType) if e.dataType == dataType => e
   }
 }
+
+/**
+ * Combines two adjacent [[catalyst.plans.logical.Limit Limit]] operators into one, merging the
+ * expressions into one single expression.
+ */
+object CombineLimits extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case ll @ Limit(le, nl @ Limit(ne, grandChild)) =>
+      Limit(If(LessThan(ne, le), ne, le), grandChild)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/4107cce5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index d3347b6..b777cf4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -135,9 +135,9 @@ case class Aggregate(
   def references = (groupingExpressions ++ aggregateExpressions).flatMap(_.references).toSet
 }
 
-case class Limit(limit: Expression, child: LogicalPlan) extends UnaryNode {
+case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
   def output = child.output
-  def references = limit.references
+  def references = limitExpr.references
 }
 
 case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {

http://git-wip-us.apache.org/repos/asf/spark/blob/4107cce5/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
new file mode 100644
index 0000000..714f018
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+
+class CombiningLimitsSuite extends OptimizerTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("Combine Limit", FixedPoint(2),
+        CombineLimits) ::
+      Batch("Constant Folding", FixedPoint(3),
+        NullPropagation,
+        ConstantFolding,
+        BooleanSimplification) :: Nil
+  }
+
+  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+
+  test("limits: combines two limits") {
+    val originalQuery =
+      testRelation
+        .select('a)
+        .limit(10)
+        .limit(5)
+
+    val optimized = Optimize(originalQuery.analyze)
+    val correctAnswer =
+      testRelation
+        .select('a)
+        .limit(5).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("limits: combines three limits") {
+    val originalQuery =
+      testRelation
+        .select('a)
+        .limit(2)
+        .limit(7)
+        .limit(5)
+
+    val optimized = Optimize(originalQuery.analyze)
+    val correctAnswer =
+      testRelation
+        .select('a)
+        .limit(2).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/4107cce5/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 8855c4e..7ad8edf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -178,14 +178,18 @@ class SchemaRDD(
   def orderBy(sortExprs: SortOrder*): SchemaRDD =
     new SchemaRDD(sqlContext, Sort(sortExprs, logicalPlan))
 
+  @deprecated("use limit with integer argument", "1.1.0")
+  def limit(limitExpr: Expression): SchemaRDD =
+    new SchemaRDD(sqlContext, Limit(limitExpr, logicalPlan))
+
   /**
-   * Limits the results by the given expressions.
+   * Limits the results by the given integer.
    * {{{
    *   schemaRDD.limit(10)
    * }}}
    */
-  def limit(limitExpr: Expression): SchemaRDD =
-    new SchemaRDD(sqlContext, Limit(limitExpr, logicalPlan))
+  def limit(limitNum: Int): SchemaRDD =
+    new SchemaRDD(sqlContext, Limit(Literal(limitNum), logicalPlan))
 
   /**
    * Performs a grouping followed by an aggregation.
@@ -374,6 +378,8 @@ class SchemaRDD(
 
   override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect()
 
+  override def take(num: Int): Array[Row] = limit(num).collect()
+
   // =======================================================================
   // Base RDD functions that do NOT change schema
   // =======================================================================

Reply via email to