This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new 18fc8e8e023 [SPARK-39915][SQL][3.3] Dataset.repartition(N) may not create N partitions Non-AQE part
18fc8e8e023 is described below

commit 18fc8e8e023868f6e7fab3422c5ce57e690d7834
Author: ulysses-you <ulyssesyo...@gmail.com>
AuthorDate: Fri Sep 9 14:43:19 2022 -0700

    [SPARK-39915][SQL][3.3] Dataset.repartition(N) may not create N partitions Non-AQE part
    
    ### What changes were proposed in this pull request?
    
    Backport of https://github.com/apache/spark/pull/37706 for branch-3.3.
    
    Skip optimizing the root user-specified repartition in `PropagateEmptyRelation`.
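
    For reference, a minimal sketch of the predicate this hinges on (it mirrors `userSpecifiedRepartition` in the `PropagateEmptyRelation.scala` diff below; the diff is authoritative):

    ```scala
    import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition, RepartitionByExpression}

    // A repartition counts as user-specified when it is Dataset.repartition(N),
    // or a repartition-by-expression that pins the partition count or expressions.
    def userSpecifiedRepartition(p: LogicalPlan): Boolean = p match {
      case _: Repartition => true
      case r: RepartitionByExpression
        if r.optNumPartitions.isDefined || r.partitionExpressions.nonEmpty => true
      case _ => false
    }
    ```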
    
    ### Why are the changes needed?
    
    Spark should preserve the final user-specified repartition, since it determines the number of output partitions.
    
    For example:
    
    ```scala
    spark.sql("select * from values(1) where 1 < rand()").repartition(1)
    
    // before:
    == Optimized Logical Plan ==
    LocalTableScan <empty>, [col1#0]
    
    // after:
    == Optimized Logical Plan ==
    Repartition 1, true
    +- LocalRelation <empty>, [col1#0]
    ```
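
    A minimal way to observe the fix end to end (a sketch assuming a running `SparkSession` named `spark`, with AQE disabled so the non-AQE rule is exercised, mirroring the new `DataFrameSuite` test below):

    ```scala
    spark.conf.set("spark.sql.adaptive.enabled", "false")
    // Before the fix, empty-relation propagation dropped the repartition,
    // so the physical plan could end up with a different partition count.
    val df = spark.sql("select * from values(1) where 1 < rand()").repartition(2)
    assert(df.queryExecution.executedPlan.execute().getNumPartitions == 2)
    ```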
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, the optimized plan for an empty relation may change.
    
    ### How was this patch tested?
    
    Added unit tests in `PropagateEmptyRelationSuite` and `DataFrameSuite`.
    
    Closes #37730 from ulysses-you/empty-3.3.
    
    Authored-by: ulysses-you <ulyssesyo...@gmail.com>
    Signed-off-by: Dongjoon Hyun <dongj...@apache.org>
---
 .../apache/spark/sql/catalyst/dsl/package.scala    |  3 ++
 .../optimizer/PropagateEmptyRelation.scala         | 42 ++++++++++++++++++++--
 .../optimizer/PropagateEmptyRelationSuite.scala    | 38 ++++++++++++++++++++
 .../adaptive/AQEPropagateEmptyRelation.scala       |  2 +-
 .../org/apache/spark/sql/DataFrameSuite.scala      |  7 ++++
 5 files changed, 88 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index d47e34b110d..000622187f4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -491,6 +491,9 @@ package object dsl {
       def repartition(num: Integer): LogicalPlan =
         Repartition(num, shuffle = true, logicalPlan)
 
+      def repartition(): LogicalPlan =
+        RepartitionByExpression(Seq.empty, logicalPlan, None)
+
       def distribute(exprs: Expression*)(n: Int): LogicalPlan =
         RepartitionByExpression(exprs, logicalPlan, numPartitions = n)
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
index 2c964fa6da3..f8e2096e443 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, TRUE_OR_FALSE_LITERAL}
 
 /**
@@ -44,6 +45,9 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, TRUE_OR_
  *     - Generate(Explode) with all empty children. Others like Hive UDTF may return results.
  */
 abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSupport {
+  // This tag is used to mark a repartition as a root repartition which is user-specified
+  private[sql] val ROOT_REPARTITION = TreeNodeTag[Unit]("ROOT_REPARTITION")
+
   protected def isEmpty(plan: LogicalPlan): Boolean = plan match {
     case p: LocalRelation => p.data.isEmpty
     case _ => false
@@ -136,8 +140,13 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup
       case _: Sort => empty(p)
       case _: GlobalLimit if !p.isStreaming => empty(p)
       case _: LocalLimit if !p.isStreaming => empty(p)
-      case _: Repartition => empty(p)
-      case _: RepartitionByExpression => empty(p)
+      case _: RepartitionOperation =>
+        if (p.getTagValue(ROOT_REPARTITION).isEmpty) {
+          empty(p)
+        } else {
+          p.unsetTagValue(ROOT_REPARTITION)
+          p
+        }
       case _: RebalancePartitions => empty(p)
      // An aggregate with non-empty group expression will return one output row per group when the
      // input to the aggregate is not empty. If the input to the aggregate is empty then all groups
@@ -160,13 +169,40 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup
       case _ => p
     }
   }
+
+  protected def userSpecifiedRepartition(p: LogicalPlan): Boolean = p match {
+    case _: Repartition => true
+    case r: RepartitionByExpression
+      if r.optNumPartitions.isDefined || r.partitionExpressions.nonEmpty => true
+    case _ => false
+  }
+
+  protected def applyInternal(plan: LogicalPlan): LogicalPlan
+
+  /**
+   * Add a [[ROOT_REPARTITION]] tag for the root user-specified repartition so this rule can
+   * skip optimizing it.
+   */
+  private def addTagForRootRepartition(plan: LogicalPlan): LogicalPlan = plan match {
+    case p: Project => p.mapChildren(addTagForRootRepartition)
+    case f: Filter => f.mapChildren(addTagForRootRepartition)
+    case r if userSpecifiedRepartition(r) =>
+      r.setTagValue(ROOT_REPARTITION, ())
+      r
+    case _ => plan
+  }
+
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    val planWithTag = addTagForRootRepartition(plan)
+    applyInternal(planWithTag)
+  }
 }
 
 /**
  * This rule runs in the normal optimizer
  */
 object PropagateEmptyRelation extends PropagateEmptyRelationBase {
-  override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+  override protected def applyInternal(p: LogicalPlan): LogicalPlan = p.transformUpWithPruning(
     _.containsAnyPattern(LOCAL_RELATION, TRUE_OR_FALSE_LITERAL), ruleId) {
     commonApplyFunc
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
index 8277e44458b..72ef8fdd91b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
@@ -309,4 +309,42 @@ class PropagateEmptyRelationSuite extends PlanTest {
     val optimized2 = Optimize.execute(plan2)
     comparePlans(optimized2, expected)
   }
+
+  test("Propagate empty relation with repartition") {
+    val emptyRelation = LocalRelation($"a".int, $"b".int)
+    comparePlans(Optimize.execute(
+      emptyRelation.repartition(1).sortBy($"a".asc).analyze
+    ), emptyRelation.analyze)
+
+    comparePlans(Optimize.execute(
+      emptyRelation.distribute($"a")(1).sortBy($"a".asc).analyze
+    ), emptyRelation.analyze)
+
+    comparePlans(Optimize.execute(
+      emptyRelation.repartition().analyze
+    ), emptyRelation.analyze)
+
+    comparePlans(Optimize.execute(
+      emptyRelation.repartition(1).sortBy($"a".asc).repartition().analyze
+    ), emptyRelation.analyze)
+  }
+
+  test("SPARK-39915: Dataset.repartition(N) may not create N partitions") {
+    val emptyRelation = LocalRelation($"a".int, $"b".int)
+    val p1 = emptyRelation.repartition(1).analyze
+    comparePlans(Optimize.execute(p1), p1)
+
+    val p2 = emptyRelation.repartition(1).select($"a").analyze
+    comparePlans(Optimize.execute(p2), p2)
+
+    val p3 = emptyRelation.repartition(1).where($"a" > rand(1)).analyze
+    comparePlans(Optimize.execute(p3), p3)
+
+    val p4 = emptyRelation.repartition(1).where($"a" > rand(1)).select($"a").analyze
+    comparePlans(Optimize.execute(p4), p4)
+
+    val p5 = emptyRelation.sortBy($"a".asc).repartition().limit(1).repartition(1).analyze
+    val expected5 = emptyRelation.repartition(1).analyze
+    comparePlans(Optimize.execute(p5), expected5)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala
index bab77515f79..132c919c291 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala
@@ -69,7 +69,7 @@ object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase {
       empty(j)
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+  override protected def applyInternal(p: LogicalPlan): LogicalPlan = p.transformUpWithPruning(
     // LOCAL_RELATION and TRUE_OR_FALSE_LITERAL pattern are matched at
     // `PropagateEmptyRelationBase.commonApplyFunc`
    // LOGICAL_QUERY_STAGE pattern is matched at `PropagateEmptyRelationBase.commonApplyFunc`
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 43aca31d138..b05d320ca07 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -3281,6 +3281,13 @@ class DataFrameSuite extends QueryTest
      Row(java.sql.Date.valueOf("2020-02-01"), java.sql.Date.valueOf("2020-02-01")) ::
        Row(java.sql.Date.valueOf("2020-01-01"), java.sql.Date.valueOf("2020-01-02")) :: Nil)
   }
+
+  test("SPARK-39915: Dataset.repartition(N) may not create N partitions") {
+    withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+      val df = spark.sql("select * from values(1) where 1 < rand()").repartition(2)
+      assert(df.queryExecution.executedPlan.execute().getNumPartitions == 2)
+    }
+  }
 }
 
 case class GroupByKey(a: Int, b: Int)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
