Repository: spark
Updated Branches:
  refs/heads/master 79cc59718 -> 2f422398b


[SPARK-25352][SQL] Perform ordered global limit when limit number is bigger than topKSortFallbackThreshold

## What changes were proposed in this pull request?

We have an optimization on global limit that distributes the limit rows evenly across all 
partitions. This optimization does not preserve row order, so it cannot be used for ordered results.
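
To make the difference concrete, here is a toy sketch in plain Scala (illustration only, not code from this patch) of the two limit strategies over three already-sorted partitions:

```scala
// Three "partitions" holding globally sorted data.
val partitions = Seq(Seq(1, 2, 3, 4), Seq(5, 6, 7, 8), Seq(9, 10, 11, 12))

// Flat global limit: take rows evenly from every partition.
val flat = partitions.flatMap(_.take(2)).take(6)   // Seq(1, 2, 5, 6, 9, 10)

// Ordered global limit: take rows from the leading partitions, in order.
val ordered = partitions.flatten.take(6)           // Seq(1, 2, 3, 4, 5, 6)

println(s"flat = $flat, ordered = $ordered")
```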

A query that ends with sort + limit is, in most cases, executed by 
`TakeOrderedAndProjectExec`.

But when the limit is larger than `SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD`, a global limit 
is used instead, and that global limit must then preserve the sorted order.
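
As a hedged sketch of the change in behavior (assuming a running `SparkSession` named `spark`; `spark.sql.execution.topKSortFallbackThreshold` is the key behind `SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD`):

```scala
// Lower the fallback threshold so the plan switch is easy to observe.
spark.conf.set("spark.sql.execution.topKSortFallbackThreshold", "100")

val df = spark.range(1000).toDF("id").sort("id")

// limit < threshold: planned as TakeOrderedAndProjectExec (top-K sort).
df.limit(99).explain()

// limit >= threshold: planned as LocalLimit + GlobalLimit over the full sort.
// With this patch, that GlobalLimitExec is created with orderedLimit = true,
// so it keeps the sorted order instead of spreading rows evenly.
df.limit(100).explain()
```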

## How was this patch tested?

Unit tests: a new `LimitSuite`, plus new and updated cases in `DataFrameSuite` and `TakeOrderedAndProjectSuite`.

Closes #22344 from viirya/SPARK-25352.

Authored-by: Liang-Chi Hsieh <vii...@gmail.com>
Signed-off-by: Wenchen Fan <wenc...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2f422398
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2f422398
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2f422398

Branch: refs/heads/master
Commit: 2f422398b524eacc89ab58e423bb134ae3ca3941
Parents: 79cc597
Author: Liang-Chi Hsieh <vii...@gmail.com>
Authored: Wed Sep 12 22:54:05 2018 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Wed Sep 12 22:54:05 2018 +0800

----------------------------------------------------------------------
 .../spark/sql/execution/SparkStrategies.scala   | 44 ++++++---
 .../org/apache/spark/sql/execution/limit.scala  |  7 +-
 .../org/apache/spark/sql/DataFrameSuite.scala   | 22 ++++-
 .../apache/spark/sql/execution/LimitSuite.scala | 81 +++++++++++++++++
 .../execution/TakeOrderedAndProjectSuite.scala  | 94 +++++++++++---------
 5 files changed, 192 insertions(+), 56 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2f422398/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index dbc6db6..7c8ce31 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -68,22 +68,42 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   object SpecialLimits extends Strategy {
     override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case ReturnAnswer(rootPlan) => rootPlan match {
-        case Limit(IntegerLiteral(limit), Sort(order, true, child))
-            if limit < conf.topKSortFallbackThreshold =>
-          TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
-        case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child)))
-            if limit < conf.topKSortFallbackThreshold =>
-          TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
+        case Limit(IntegerLiteral(limit), s@Sort(order, true, child)) =>
+          if (limit < conf.topKSortFallbackThreshold) {
+            TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
+          } else {
+            GlobalLimitExec(limit,
+              LocalLimitExec(limit, planLater(s)),
+              orderedLimit = true) :: Nil
+          }
+        case Limit(IntegerLiteral(limit), p@Project(projectList, Sort(order, true, child))) =>
+          if (limit < conf.topKSortFallbackThreshold) {
+            TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
+          } else {
+            GlobalLimitExec(limit,
+              LocalLimitExec(limit, planLater(p)),
+              orderedLimit = true) :: Nil
+          }
         case Limit(IntegerLiteral(limit), child) =>
           CollectLimitExec(limit, planLater(child)) :: Nil
         case other => planLater(other) :: Nil
       }
-      case Limit(IntegerLiteral(limit), Sort(order, true, child))
-          if limit < conf.topKSortFallbackThreshold =>
-        TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
-      case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child)))
-          if limit < conf.topKSortFallbackThreshold =>
-        TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
+      case Limit(IntegerLiteral(limit), s@Sort(order, true, child)) =>
+        if (limit < conf.topKSortFallbackThreshold) {
+          TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
+        } else {
+          GlobalLimitExec(limit,
+            LocalLimitExec(limit, planLater(s)),
+            orderedLimit = true) :: Nil
+        }
+      case Limit(IntegerLiteral(limit), p@Project(projectList, Sort(order, true, child))) =>
+        if (limit < conf.topKSortFallbackThreshold) {
+          TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
+        } else {
+          GlobalLimitExec(limit,
+            LocalLimitExec(limit, planLater(p)),
+            orderedLimit = true) :: Nil
+        }
       case _ => Nil
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/2f422398/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index fb46970..1a09632 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -98,7 +98,8 @@ case class LocalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode wi
 /**
  * Take the `limit` elements of the child output.
  */
-case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode {
+case class GlobalLimitExec(limit: Int, child: SparkPlan,
+                           orderedLimit: Boolean = false) extends UnaryExecNode {
 
   override def output: Seq[Attribute] = child.output
 
@@ -126,7 +127,9 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode {
     // When enabled, Spark goes to take rows at each partition repeatedly until reaching
     // limit number. When disabled, Spark takes all rows at first partition, then rows
     // at second partition ..., until reaching limit number.
-    val flatGlobalLimit = sqlContext.conf.limitFlatGlobalLimit
+    // The optimization is disabled when it is needed to keep the original order of rows
+    // before global sort, e.g., select * from table order by col limit 10.
+    val flatGlobalLimit = sqlContext.conf.limitFlatGlobalLimit && !orderedLimit
 
     val shuffled = new ShuffledRowRDD(shuffleDependency)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/2f422398/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 279b7b8..f001b13 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.expressions.Uuid
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union}
-import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.{FilterExec, QueryExecution, TakeOrderedAndProjectExec, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.functions._
@@ -2552,6 +2552,26 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("SPARK-25352: Ordered global limit when more than topKSortFallbackThreshold ") {
+    withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") {
+      val baseDf = spark.range(1000).toDF.repartition(3).sort("id")
+
+      withSQLConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> "100") {
+        val expected = baseDf.limit(99)
+        val takeOrderedNode1 = expected.queryExecution.executedPlan
+          .find(_.isInstanceOf[TakeOrderedAndProjectExec])
+        assert(takeOrderedNode1.isDefined)
+
+        val result = baseDf.limit(100)
+        val takeOrderedNode2 = result.queryExecution.executedPlan
+          .find(_.isInstanceOf[TakeOrderedAndProjectExec])
+        assert(takeOrderedNode2.isEmpty)
+
+        checkAnswer(expected, result.collect().take(99))
+      }
+    }
+  }
+
   test("SPARK-25368 Incorrect predicate pushdown returns wrong result") {
     def check(newCol: Column, filter: Column, result: Seq[Row]): Unit = {
       val df1 = spark.createDataFrame(Seq(

http://git-wip-us.apache.org/repos/asf/spark/blob/2f422398/sql/core/src/test/scala/org/apache/spark/sql/execution/LimitSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/LimitSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/LimitSuite.scala
new file mode 100644
index 0000000..a7840a5
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/LimitSuite.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import scala.util.Random
+
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSQLContext
+
+
+class LimitSuite extends SparkPlanTest with SharedSQLContext {
+
+  private var rand: Random = _
+  private var seed: Long = 0
+
+  protected override def beforeAll(): Unit = {
+    super.beforeAll()
+    seed = System.currentTimeMillis()
+    rand = new Random(seed)
+  }
+
+  test("Produce ordered global limit if more than topKSortFallbackThreshold") {
+    withSQLConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> "100") {
+      val df = LimitTest.generateRandomInputData(spark, rand).sort("a")
+
+      val globalLimit = df.limit(99).queryExecution.executedPlan.collect {
+        case g: GlobalLimitExec => g
+      }
+      assert(globalLimit.size == 0)
+
+      val topKSort = df.limit(99).queryExecution.executedPlan.collect {
+        case t: TakeOrderedAndProjectExec => t
+      }
+      assert(topKSort.size == 1)
+
+      val orderedGlobalLimit = df.limit(100).queryExecution.executedPlan.collect {
+        case g: GlobalLimitExec => g
+      }
+      assert(orderedGlobalLimit.size == 1 && orderedGlobalLimit(0).orderedLimit == true)
+    }
+  }
+
+  test("Ordered global limit") {
+    val baseDf = LimitTest.generateRandomInputData(spark, rand)
+      .select("a").repartition(3).sort("a")
+
+    withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") {
+      val orderedGlobalLimit = GlobalLimitExec(3, baseDf.queryExecution.sparkPlan,
+        orderedLimit = true)
+      val orderedGlobalLimitResult = SparkPlanTest.executePlan(orderedGlobalLimit, spark.sqlContext)
+        .map(_.getInt(0))
+
+      val globalLimit = GlobalLimitExec(3, baseDf.queryExecution.sparkPlan, orderedLimit = false)
+      val globalLimitResult = SparkPlanTest.executePlan(globalLimit, spark.sqlContext)
+          .map(_.getInt(0))
+
+      // Global limit without order takes values at each partition sequentially.
+      // After global sort, the values in second partition must be larger than the values
+      // in first partition.
+      assert(orderedGlobalLimitResult(0) == globalLimitResult(0))
+      assert(orderedGlobalLimitResult(1) < globalLimitResult(1))
+      assert(orderedGlobalLimitResult(2) < globalLimitResult(2))
+    }
+  }
+}
+

http://git-wip-us.apache.org/repos/asf/spark/blob/2f422398/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
index f076959..9322204 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
 
 import scala.util.Random
 
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.internal.SQLConf
@@ -32,28 +32,10 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext {
   private var rand: Random = _
   private var seed: Long = 0
 
-  private val originalLimitFlatGlobalLimit = SQLConf.get.getConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT)
-
   protected override def beforeAll(): Unit = {
     super.beforeAll()
     seed = System.currentTimeMillis()
     rand = new Random(seed)
-
-    // Disable the optimization to make Sort-Limit match `TakeOrderedAndProject` semantics.
-    SQLConf.get.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, false)
-  }
-
-  protected override def afterAll() = {
-    SQLConf.get.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, originalLimitFlatGlobalLimit)
-    super.afterAll()
-  }
-
-  private def generateRandomInputData(): DataFrame = {
-    val schema = new StructType()
-      .add("a", IntegerType, nullable = false)
-      .add("b", IntegerType, nullable = false)
-    val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt()))
-    spark.createDataFrame(sparkContext.parallelize(Random.shuffle(inputData), 10), schema)
   }
 
   /**
@@ -66,32 +48,62 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext {
   val sortOrder = 'a.desc :: 'b.desc :: Nil
 
   test("TakeOrderedAndProject.doExecute without project") {
-    withClue(s"seed = $seed") {
-      checkThatPlansAgree(
-        generateRandomInputData(),
-        input =>
-          noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)),
-        input =>
-          GlobalLimitExec(limit,
-            LocalLimitExec(limit,
-              SortExec(sortOrder, true, input))),
-        sortAnswers = false)
+    withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "false") {
+      withClue(s"seed = $seed") {
+        checkThatPlansAgree(
+          LimitTest.generateRandomInputData(spark, rand),
+          input =>
+            noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)),
+          input =>
+            GlobalLimitExec(limit,
+              LocalLimitExec(limit,
+                SortExec(sortOrder, true, input))),
+          sortAnswers = false)
+      }
     }
   }
 
   test("TakeOrderedAndProject.doExecute with project") {
-    withClue(s"seed = $seed") {
-      checkThatPlansAgree(
-        generateRandomInputData(),
-        input =>
-          noOpFilter(
-            TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)),
-        input =>
-          GlobalLimitExec(limit,
-            LocalLimitExec(limit,
-              ProjectExec(Seq(input.output.last),
-                SortExec(sortOrder, true, input)))),
-        sortAnswers = false)
+    withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "false") {
+      withClue(s"seed = $seed") {
+        checkThatPlansAgree(
+          LimitTest.generateRandomInputData(spark, rand),
+          input =>
+            noOpFilter(
+              TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)),
+          input =>
+            GlobalLimitExec(limit,
+              LocalLimitExec(limit,
+                ProjectExec(Seq(input.output.last),
+                  SortExec(sortOrder, true, input)))),
+          sortAnswers = false)
+      }
     }
   }
+
+  test("TakeOrderedAndProject.doExecute equals to ordered global limit") {
+    withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") {
+      withClue(s"seed = $seed") {
+        checkThatPlansAgree(
+          LimitTest.generateRandomInputData(spark, rand),
+          input =>
+            noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)),
+          input =>
+            GlobalLimitExec(limit,
+              LocalLimitExec(limit,
+                SortExec(sortOrder, true, input)), orderedLimit = true),
+          sortAnswers = false)
+      }
+    }
+  }
+}
+
+object LimitTest {
+  def generateRandomInputData(spark: SparkSession, rand: Random): DataFrame = {
+    val schema = new StructType()
+      .add("a", IntegerType, nullable = false)
+      .add("b", IntegerType, nullable = false)
+    val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt()))
+    spark.createDataFrame(spark.sparkContext.parallelize(Random.shuffle(inputData), 10), schema)
+  }
 }

