This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new 82351bee9bb [SPARK-39915][SQL] Ensure the output partitioning is user-specified in AQE
82351bee9bb is described below

commit 82351bee9bba75b631dc3a93aa5bc4cd9f46724c
Author: ulysses-you <ulyssesyo...@gmail.com>
AuthorDate: Wed Sep 14 08:40:08 2022 +0800

    [SPARK-39915][SQL] Ensure the output partitioning is user-specified in AQE

    ### What changes were proposed in this pull request?

    - Support getting the user-specified root repartition through `DeserializeToObjectExec`
    - Skip the empty-relation optimization for a root repartition that is user-specified
    - Add a new rule `AdjustShuffleExchangePosition` to adjust the shuffle we add back, so that we can restore the shuffle safely

    ### Why are the changes needed?

    AQE cannot always respect a user-specified repartition. The main reasons are:
    1. the AQE optimizer converts an empty plan to a local relation, which does not preserve the partitioning info
    2. the AQE `requiredDistribution` machinery only restores a repartition and does not look through `DeserializeToObjectExec`

    After the fix, the partition number of `spark.range(0).repartition(5).rdd.getNumPartitions` should be 5.
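    For illustration, a minimal sketch of the reported behavior (assuming a local
    SparkSession; before this fix the assertion fails because AQE folds the empty
    range into a LocalRelation and drops the user-specified repartition):

        import org.apache.spark.sql.SparkSession

        val spark = SparkSession.builder()
          .master("local[*]")
          .appName("SPARK-39915-repro")
          .getOrCreate()

        // Should be 5 partitions, even though the relation is empty.
        assert(spark.range(0).repartition(5).rdd.getNumPartitions == 5)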
    ### Does this PR introduce _any_ user-facing change?

    Yes, it ensures the user-specified distribution is respected.

    ### How was this patch tested?

    Added tests.

    Closes #37612 from ulysses-you/output-partition.

    Lead-authored-by: ulysses-you <ulyssesyo...@gmail.com>
    Co-authored-by: Wenchen Fan <cloud0...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 801ca252f43b20cdd629c01d734ca9049e6eccf4)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../optimizer/PropagateEmptyRelation.scala         | 23 ++++++++----
 .../adaptive/AQEPropagateEmptyRelation.scala       | 15 +++++++-
 .../spark/sql/execution/adaptive/AQEUtils.scala    |  3 +-
 .../execution/adaptive/AdaptiveSparkPlanExec.scala |  1 +
 .../adaptive/AdjustShuffleExchangePosition.scala   | 43 ++++++++++++++++++++++
 .../sql/execution/adaptive/LogicalQueryStage.scala | 14 +++++--
 .../org/apache/spark/sql/DataFrameSuite.scala      |  2 +-
 .../adaptive/AdaptiveQueryExecSuite.scala          | 43 ++++++++++++++++++++++
 8 files changed, 130 insertions(+), 14 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
index f8e2096e443..f3606566cb1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
-import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, TRUE_OR_FALSE_LITERAL}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, REPARTITION_OPERATION, TRUE_OR_FALSE_LITERAL}
 
 /**
  * The base class of two rules in the normal and AQE Optimizer. It simplifies query plans with
@@ -183,13 +183,20 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup
    * Add a [[ROOT_REPARTITION]] tag for the root user-specified repartition so this rule can
    * skip optimizing it.
    */
-  private def addTagForRootRepartition(plan: LogicalPlan): LogicalPlan = plan match {
-    case p: Project => p.mapChildren(addTagForRootRepartition)
-    case f: Filter => f.mapChildren(addTagForRootRepartition)
-    case r if userSpecifiedRepartition(r) =>
-      r.setTagValue(ROOT_REPARTITION, ())
-      r
-    case _ => plan
+  private def addTagForRootRepartition(plan: LogicalPlan): LogicalPlan = {
+    if (!plan.containsPattern(REPARTITION_OPERATION)) {
+      return plan
+    }
+
+    plan match {
+      case p: Project => p.mapChildren(addTagForRootRepartition)
+      case f: Filter => f.mapChildren(addTagForRootRepartition)
+      case d: DeserializeToObject => d.mapChildren(addTagForRootRepartition)
+      case r if userSpecifiedRepartition(r) =>
+        r.setTagValue(ROOT_REPARTITION, ())
+        r
+      case _ => plan
+    }
   }
 
   override def apply(plan: LogicalPlan): LogicalPlan = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala
index 132c919c291..7951a6f36b9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.planning.ExtractSingleColumnNullAwareAntiJo
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, LOGICAL_QUERY_STAGE, TRUE_OR_FALSE_LITERAL}
 import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
+import org.apache.spark.sql.execution.exchange.{REPARTITION_BY_COL, REPARTITION_BY_NUM, ShuffleExchangeLike}
 import org.apache.spark.sql.execution.joins.HashedRelationWithAllNullKeys
 
 /**
@@ -33,11 +34,16 @@ import org.apache.spark.sql.execution.joins.HashedRelationWithAllNullKeys
  */
 object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase {
   override protected def isEmpty(plan: LogicalPlan): Boolean =
-    super.isEmpty(plan) || getEstimatedRowCount(plan).contains(0)
+    super.isEmpty(plan) || (!isRootRepartition(plan) && getEstimatedRowCount(plan).contains(0))
 
   override protected def nonEmpty(plan: LogicalPlan): Boolean =
     super.nonEmpty(plan) || getEstimatedRowCount(plan).exists(_ > 0)
 
+  private def isRootRepartition(plan: LogicalPlan): Boolean = plan match {
+    case l: LogicalQueryStage if l.getTagValue(ROOT_REPARTITION).isDefined => true
+    case _ => false
+  }
+
   // The returned value follows:
   //   - 0 means the plan must produce 0 row
   //   - positive value means an estimated row count which can be over-estimated
@@ -69,6 +75,13 @@ object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase {
       empty(j)
   }
 
+  override protected def userSpecifiedRepartition(p: LogicalPlan): Boolean = p match {
+    case LogicalQueryStage(_, ShuffleQueryStageExec(_, shuffle: ShuffleExchangeLike, _))
+      if shuffle.shuffleOrigin == REPARTITION_BY_COL ||
+        shuffle.shuffleOrigin == REPARTITION_BY_NUM => true
+    case _ => false
+  }
+
   override protected def applyInternal(p: LogicalPlan): LogicalPlan = p.transformUpWithPruning(
     // LOCAL_RELATION and TRUE_OR_FALSE_LITERAL pattern are matched at
     // `PropagateEmptyRelationBase.commonApplyFunc`
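The two shuffle origins matched above correspond to the explicit repartition APIs; a small
sketch of which calls produce them (assuming the `spark` session from the earlier sketch):

    import spark.implicits._

    // REPARTITION_BY_NUM: the user fixed the partition count, so AQE must keep it.
    val byNum = spark.range(10).repartition(5, $"id")
    // REPARTITION_BY_COL: the user fixed only the keys; AQE may still coalesce
    // partitions, but the shuffle must survive empty-relation propagation.
    val byCol = spark.range(10).repartition($"id")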
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEUtils.scala
index 51833012a12..1a0836ed752 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEUtils.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.adaptive
 
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, HashPartitioning, UnspecifiedDistribution}
-import org.apache.spark.sql.execution.{CollectMetricsExec, FilterExec, ProjectExec, SortExec, SparkPlan}
+import org.apache.spark.sql.execution.{CollectMetricsExec, DeserializeToObjectExec, FilterExec, ProjectExec, SortExec, SparkPlan}
 import org.apache.spark.sql.execution.exchange.{REPARTITION_BY_COL, REPARTITION_BY_NUM, ShuffleExchangeExec}
 
 object AQEUtils {
@@ -41,6 +41,7 @@ object AQEUtils {
     case f: FilterExec => getRequiredDistribution(f.child)
     case s: SortExec if !s.global => getRequiredDistribution(s.child)
     case c: CollectMetricsExec => getRequiredDistribution(c.child)
+    case d: DeserializeToObjectExec => getRequiredDistribution(d.child)
     case p: ProjectExec => getRequiredDistribution(p.child).flatMap {
       case h: ClusteredDistribution =>
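`Dataset.rdd` is the motivating case for the new `DeserializeToObjectExec` branch above:
converting to an RDD plans a deserializer on top of the query, so a user-specified
repartition sits below it. A minimal sketch of the behavior this enables (same assumed
session as the earlier sketches):

    import spark.implicits._

    val df = spark.range(10).repartition(5, $"id")
    // Without looking through DeserializeToObjectExec, getRequiredDistribution
    // would stop at the deserializer that .rdd plans and miss the repartition.
    assert(df.rdd.getNumPartitions == 5)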
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 6c9c0e1cda4..9b6c98fa0e5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -116,6 +116,7 @@ case class AdaptiveSparkPlanExec(
     Seq(
       RemoveRedundantProjects,
       ensureRequirements,
+      AdjustShuffleExchangePosition,
       ValidateSparkPlan,
       ReplaceHashWithSortAgg,
       RemoveRedundantSorts,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdjustShuffleExchangePosition.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdjustShuffleExchangePosition.scala
new file mode 100644
index 00000000000..f211b6cc8a0
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdjustShuffleExchangePosition.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{DeserializeToObjectExec, SparkPlan}
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
+
+/**
+ * This rule adjusts the position of a shuffle exchange whose child is a special SparkPlan
+ * that does not allow a shuffle on top of it.
+ */
+object AdjustShuffleExchangePosition extends Rule[SparkPlan] {
+  private def shouldAdjust(plan: SparkPlan): Boolean = plan match {
+    // `DeserializeToObjectExec` is used by Spark internally, e.g. for `Dataset.rdd`. It produces
+    // safe rows and must be the root node, because SQL operators only accept unsafe rows as
+    // input. This conflicts with the AQE framework, since we may add a shuffle back during
+    // re-optimization to preserve the user-specified repartition, so here we swap its position
+    // with the shuffle.
+    case _: DeserializeToObjectExec => true
+    case _ => false
+  }
+
+  override def apply(plan: SparkPlan): SparkPlan = plan match {
+    case shuffle: ShuffleExchangeLike if shouldAdjust(shuffle.child) =>
+      shuffle.child.withNewChildren(shuffle.withNewChildren(shuffle.child.children) :: Nil)
+    case _ => plan
+  }
+}
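In plan terms, the rule swaps the two nodes that re-optimization can stack in the wrong
order; roughly (an illustrative sketch, not real explain output):

    // Before: the shuffle that AQE adds back lands on top of the deserializer.
    ShuffleExchange
    +- DeserializeToObjectExec
       +- <child>

    // After: the deserializer is restored as the root, with the shuffle below it.
    DeserializeToObjectExec
    +- ShuffleExchange
       +- <child>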
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStage.scala
index f8b786778a7..5e6f1b5a884 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStage.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStage.scala
@@ -18,8 +18,8 @@ package org.apache.spark.sql.execution.adaptive
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
-import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
-import org.apache.spark.sql.catalyst.trees.TreePattern.{LOGICAL_QUERY_STAGE, TreePattern}
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, RepartitionOperation, Statistics}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{LOGICAL_QUERY_STAGE, REPARTITION_OPERATION, TreePattern}
 import org.apache.spark.sql.execution.SparkPlan
 
 /**
@@ -40,7 +40,15 @@ case class LogicalQueryStage(
   override def output: Seq[Attribute] = logicalPlan.output
   override val isStreaming: Boolean = logicalPlan.isStreaming
   override val outputOrdering: Seq[SortOrder] = physicalPlan.outputOrdering
-  override protected val nodePatterns: Seq[TreePattern] = Seq(LOGICAL_QUERY_STAGE)
+  override protected val nodePatterns: Seq[TreePattern] = {
+    // Repartition is a special node: it represents a shuffle exchange, and in AQE a
+    // repartition is always wrapped in a `LogicalQueryStage`.
+    val repartitionPattern = logicalPlan match {
+      case _: RepartitionOperation => Some(REPARTITION_OPERATION)
+      case _ => None
+    }
+    Seq(LOGICAL_QUERY_STAGE) ++ repartitionPattern
+  }
 
   override def computeStats(): Statistics = {
     // TODO this is not accurate when there are other physical nodes above QueryStageExec.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index b05d320ca07..a696c3fd499 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -482,7 +482,7 @@ class DataFrameSuite extends QueryTest
       testData.select("key").coalesce(1).select("key"),
       testData.select("key").collect().toSeq)
 
-    assert(spark.emptyDataFrame.coalesce(1).rdd.partitions.size === 0)
+    assert(spark.emptyDataFrame.coalesce(1).rdd.partitions.size === 1)
   }
 
   test("convert $\"attribute name\" into unresolved attribute") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index d5c933fbc8b..55f092e2d60 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -2606,6 +2606,49 @@ class AdaptiveQueryExecSuite
       assert(findTopLevelBroadcastNestedLoopJoin(adaptivePlan).size == 1)
     }
   }
+
+  test("SPARK-39915: Dataset.repartition(N) may not create N partitions") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "6") {
+      // partitioning: HashPartitioning
+      // shuffleOrigin: REPARTITION_BY_NUM
+      assert(spark.range(0).repartition(5, $"id").rdd.getNumPartitions == 5)
+      // shuffleOrigin: REPARTITION_BY_COL
+      // The minimum partition number after AQE coalesce is 1
+      assert(spark.range(0).repartition($"id").rdd.getNumPartitions == 1)
+      // through project
+      assert(spark.range(0).selectExpr("id % 3 as c1", "id % 7 as c2")
+        .repartition(5, $"c1").select($"c2").rdd.getNumPartitions == 5)
+
+      // partitioning: RangePartitioning
+      // shuffleOrigin: REPARTITION_BY_NUM
+      // The minimum partition number of RangePartitioner is 1
+      assert(spark.range(0).repartitionByRange(5, $"id").rdd.getNumPartitions == 1)
+      // shuffleOrigin: REPARTITION_BY_COL
+      assert(spark.range(0).repartitionByRange($"id").rdd.getNumPartitions == 1)
+
+      // partitioning: RoundRobinPartitioning
+      // shuffleOrigin: REPARTITION_BY_NUM
+      assert(spark.range(0).repartition(5).rdd.getNumPartitions == 5)
+      // shuffleOrigin: REBALANCE_PARTITIONS_BY_NONE
+      assert(spark.range(0).repartition().rdd.getNumPartitions == 0)
+      // through project
+      assert(spark.range(0).selectExpr("id % 3 as c1", "id % 7 as c2")
+        .repartition(5).select($"c2").rdd.getNumPartitions == 5)
+
+      // partitioning: SinglePartition
+      assert(spark.range(0).repartition(1).rdd.getNumPartitions == 1)
+    }
+  }
+
+  test("SPARK-39915: Ensure the output partitioning is user-specified") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      val df1 = spark.range(1).selectExpr("id as c1")
+      val df2 = spark.range(1).selectExpr("id as c2")
+      val df = df1.join(df2, col("c1") === col("c2")).repartition(3, col("c1"))
+      assert(df.rdd.getNumPartitions == 3)
+    }
+  }
 }
 
 /**

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org