This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ff7ab34a695 [SPARK-39915][SQL] Dataset.repartition(N) may not create N partitions Non-AQE part
ff7ab34a695 is described below

commit ff7ab34a6957965d74504ed1d36de21d1ad7319c
Author: ulysses-you <[email protected]>
AuthorDate: Tue Aug 30 14:30:41 2022 +0800

    [SPARK-39915][SQL] Dataset.repartition(N) may not create N partitions Non-AQE part
    
    ### What changes were proposed in this pull request?
    
    Skip optimizing the root user-specified repartition in `PropagateEmptyRelation`.
    
    ### Why are the changes needed?
    
    Spark should preserve the final user-specified repartition, because it determines the number of output partitions the user asked for.
    
    For example:
    
    ```scala
    spark.sql("select * from values(1) where 1 < rand()").repartition(1)
    
    // before:
    == Optimized Logical Plan ==
    LocalTableScan <empty>, [col1#0]
    
    // after:
    == Optimized Logical Plan ==
    Repartition 1, true
    +- LocalRelation <empty>, [col1#0]
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
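    Yes. The optimized plan of an empty query may change: a root user-specified repartition is now preserved instead of being removed.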
    
    ### How was this patch tested?
    
    Added new tests in `PropagateEmptyRelationSuite` and `DataFrameSuite`.
    
    Closes #37706 from ulysses-you/empty.
    
    Authored-by: ulysses-you <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../apache/spark/sql/catalyst/dsl/package.scala    |  3 ++
 .../optimizer/PropagateEmptyRelation.scala         | 42 ++++++++++++++++++++--
 .../optimizer/PropagateEmptyRelationSuite.scala    | 38 ++++++++++++++++++++
 .../adaptive/AQEPropagateEmptyRelation.scala       |  2 +-
 .../org/apache/spark/sql/DataFrameSuite.scala      |  7 ++++
 5 files changed, 88 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 118d3e85b71..86d85abc6f3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -501,6 +501,9 @@ package object dsl {
       def repartition(num: Integer): LogicalPlan =
         Repartition(num, shuffle = true, logicalPlan)
 
+      def repartition(): LogicalPlan =
+        RepartitionByExpression(Seq.empty, logicalPlan, None)
+
       def distribute(exprs: Expression*)(n: Int): LogicalPlan =
         RepartitionByExpression(exprs, logicalPlan, numPartitions = n)
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
index 18c344f10f6..9e864d036ef 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, TRUE_OR_FALSE_LITERAL}
 
 /**
@@ -44,6 +45,9 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, TRUE_OR_
  *     - Generate(Explode) with all empty children. Others like Hive UDTF may return results.
  */
 abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSupport {
+  // This tag is used to mark a repartition as a root repartition which is user-specified
+  private[sql] val ROOT_REPARTITION = TreeNodeTag[Unit]("ROOT_REPARTITION")
+
   protected def isEmpty(plan: LogicalPlan): Boolean = plan match {
     case p: LocalRelation => p.data.isEmpty
     case _ => false
@@ -137,8 +141,13 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup
       case _: GlobalLimit if !p.isStreaming => empty(p)
       case _: LocalLimit if !p.isStreaming => empty(p)
       case _: Offset => empty(p)
-      case _: Repartition => empty(p)
-      case _: RepartitionByExpression => empty(p)
+      case _: RepartitionOperation =>
+        if (p.getTagValue(ROOT_REPARTITION).isEmpty) {
+          empty(p)
+        } else {
+          p.unsetTagValue(ROOT_REPARTITION)
+          p
+        }
       case _: RebalancePartitions => empty(p)
       // An aggregate with non-empty group expression will return one output row per group when the
       // input to the aggregate is not empty. If the input to the aggregate is empty then all groups
@@ -162,13 +171,40 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup
       case _ => p
     }
   }
+
+  protected def userSpecifiedRepartition(p: LogicalPlan): Boolean = p match {
+    case _: Repartition => true
+    case r: RepartitionByExpression
+      if r.optNumPartitions.isDefined || r.partitionExpressions.nonEmpty => true
+    case _ => false
+  }
+
+  protected def applyInternal(plan: LogicalPlan): LogicalPlan
+
+  /**
+   * Add a [[ROOT_REPARTITION]] tag for the root user-specified repartition so this rule can
+   * skip optimize it.
+   */
+  private def addTagForRootRepartition(plan: LogicalPlan): LogicalPlan = plan match {
+    case p: Project => p.mapChildren(addTagForRootRepartition)
+    case f: Filter => f.mapChildren(addTagForRootRepartition)
+    case r if userSpecifiedRepartition(r) =>
+      r.setTagValue(ROOT_REPARTITION, ())
+      r
+    case _ => plan
+  }
+
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    val planWithTag = addTagForRootRepartition(plan)
+    applyInternal(planWithTag)
+  }
 }
 
 /**
  * This rule runs in the normal optimizer
  */
 object PropagateEmptyRelation extends PropagateEmptyRelationBase {
-  override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+  override protected def applyInternal(p: LogicalPlan): LogicalPlan = p.transformUpWithPruning(
     _.containsAnyPattern(LOCAL_RELATION, TRUE_OR_FALSE_LITERAL), ruleId) {
     commonApplyFunc
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
index e39bf0fffb9..fe45e02c67f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
@@ -327,4 +327,42 @@ class PropagateEmptyRelationSuite extends PlanTest {
       .fromExternalRows(Seq($"a".int, $"b".int, $"window".long.withNullability(false)), Nil)
     comparePlans(Optimize.execute(originalQuery.analyze), expected.analyze)
   }
+
+  test("Propagate empty relation with repartition") {
+    val emptyRelation = LocalRelation($"a".int, $"b".int)
+    comparePlans(Optimize.execute(
+      emptyRelation.repartition(1).sortBy($"a".asc).analyze
+    ), emptyRelation.analyze)
+
+    comparePlans(Optimize.execute(
+      emptyRelation.distribute($"a")(1).sortBy($"a".asc).analyze
+    ), emptyRelation.analyze)
+
+    comparePlans(Optimize.execute(
+      emptyRelation.repartition().analyze
+    ), emptyRelation.analyze)
+
+    comparePlans(Optimize.execute(
+      emptyRelation.repartition(1).sortBy($"a".asc).repartition().analyze
+    ), emptyRelation.analyze)
+  }
+
+  test("SPARK-39915: Dataset.repartition(N) may not create N partitions") {
+    val emptyRelation = LocalRelation($"a".int, $"b".int)
+    val p1 = emptyRelation.repartition(1).analyze
+    comparePlans(Optimize.execute(p1), p1)
+
+    val p2 = emptyRelation.repartition(1).select($"a").analyze
+    comparePlans(Optimize.execute(p2), p2)
+
+    val p3 = emptyRelation.repartition(1).where($"a" > rand(1)).analyze
+    comparePlans(Optimize.execute(p3), p3)
+
+    val p4 = emptyRelation.repartition(1).where($"a" > rand(1)).select($"a").analyze
+    comparePlans(Optimize.execute(p4), p4)
+
+    val p5 = emptyRelation.sortBy($"a".asc).repartition().limit(1).repartition(1).analyze
+    val expected5 = emptyRelation.repartition(1).analyze
+    comparePlans(Optimize.execute(p5), expected5)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala
index bab77515f79..132c919c291 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala
@@ -69,7 +69,7 @@ object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase {
       empty(j)
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+  override protected def applyInternal(p: LogicalPlan): LogicalPlan = p.transformUpWithPruning(
     // LOCAL_RELATION and TRUE_OR_FALSE_LITERAL pattern are matched at
     // `PropagateEmptyRelationBase.commonApplyFunc`
     // LOGICAL_QUERY_STAGE pattern is matched at `PropagateEmptyRelationBase.commonApplyFunc`
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 74b01b691b1..9dca09fa2e9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -3419,6 +3419,13 @@ class DataFrameSuite extends QueryTest
       Row(java.sql.Date.valueOf("2020-02-01"), java.sql.Date.valueOf("2020-02-01")) ::
         Row(java.sql.Date.valueOf("2020-01-01"), java.sql.Date.valueOf("2020-01-02")) :: Nil)
   }
+
+  test("SPARK-39915: Dataset.repartition(N) may not create N partitions") {
+    withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+      val df = spark.sql("select * from values(1) where 1 < rand()").repartition(2)
+      assert(df.queryExecution.executedPlan.execute().getNumPartitions == 2)
+    }
+  }
 }
 
 case class GroupByKey(a: Int, b: Int)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to