This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new b01747341918 [SPARK-53401][SQL] Enable Direct Passthrough Partitioning in the DataFrame API b01747341918 is described below commit b01747341918b4ac4e13ec35c9e816fa9239754b Author: Shujing Yang <shujing.y...@databricks.com> AuthorDate: Wed Sep 3 23:13:12 2025 +0800 [SPARK-53401][SQL] Enable Direct Passthrough Partitioning in the DataFrame API ### What changes were proposed in this pull request? Currently, Spark's DataFrame repartition() API only supports hash-based and range-based partitioning strategies. Users who need precise control over which partition each row goes to (similar to RDD's partitionBy with custom partitioners) have no direct way to achieve this at the DataFrame level. This PR introduces a new DataFrame API, `repartitionById(numPartitions, partitionIdExpr)`, an API that allows users to directly specify target partition IDs in DataFrame repartitioning operations. The partition ID expression must be of INT type: ``` // Partition rows based on a computed partition ID val df = spark.range(100).withColumn("partition_id", (col("id") % 10).cast("int")) val repartitioned = df.repartitionById(10, $"partition_id") ``` ### Why are the changes needed? Enable precise control over which partition each row goes to (similar to RDD's partitionBy with custom partitioners) at the DataFrame level. ### Does this PR introduce _any_ user-facing change? Yes. ``` // Partition rows based on a computed partition ID val df = spark.range(100).withColumn("partition_id", (col("id") % 10).cast("int")) val repartitioned = df.repartitionById(10, $"partition_id") ``` ### How was this patch tested? New Unit Tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #52153 from shujingyang-db/direct-partitionId-pass-through.
Lead-authored-by: Shujing Yang <shujing.y...@databricks.com> Co-authored-by: Shujing Yang <135740748+shujingyang...@users.noreply.github.com> Co-authored-by: Wenchen Fan <cloud0...@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../expressions/DirectShufflePartitionID.scala | 46 ++++++ .../plans/logical/basicLogicalOperators.scala | 36 +++-- .../sql/catalyst/plans/physical/partitioning.scala | 41 +++++ .../org/apache/spark/sql/classic/Dataset.scala | 15 ++ .../execution/exchange/ShuffleExchangeExec.scala | 9 ++ .../apache/spark/sql/execution/PlannerSuite.scala | 180 ++++++++++++++++++++- 6 files changed, 313 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DirectShufflePartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DirectShufflePartitionID.scala new file mode 100644 index 000000000000..cc59b3376ce6 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DirectShufflePartitionID.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType} + +/** + * Expression that takes a partition ID value and passes it through directly for use in + * shuffle partitioning. This is used with RepartitionByExpression to allow users to + * directly specify target partition IDs. + * + * The child expression must evaluate to an integer (IntegerType). At shuffle time the value is + * mapped into [0, numPartitions) with pmod, and a null value is sent to partition 0. + */ +case class DirectShufflePartitionID(child: Expression) + extends UnaryExpression + with ExpectsInputTypes + with Unevaluable { + + override def dataType: DataType = child.dataType + + override def inputTypes: Seq[AbstractDataType] = IntegerType :: Nil + + override def nullable: Boolean = false + + override val prettyName: String = "direct_shuffle_partition_id" + + override protected def withNewChildInternal(newChild: Expression): DirectShufflePartitionID = + copy(child = newChild) +} + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index add31448bef7..810f2b027acb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable.VIEW_STORING_ANALYZED_ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, TypedImperativeAggregate} import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition} +import
org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, ShufflePartitionIdPassThrough, SinglePartition} import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.types.DataTypeUtils @@ -1871,19 +1871,29 @@ trait HasPartitionExpressions extends SQLConfHelper { protected def partitioning: Partitioning = if (partitionExpressions.isEmpty) { RoundRobinPartitioning(numPartitions) } else { - val (sortOrder, nonSortOrder) = partitionExpressions.partition(_.isInstanceOf[SortOrder]) - require(sortOrder.isEmpty || nonSortOrder.isEmpty, - s"${getClass.getSimpleName} expects that either all its `partitionExpressions` are of type " + - "`SortOrder`, which means `RangePartitioning`, or none of them are `SortOrder`, which " + - "means `HashPartitioning`. In this case we have:" + - s""" - |SortOrder: $sortOrder - |NonSortOrder: $nonSortOrder - """.stripMargin) - if (sortOrder.nonEmpty) { - RangePartitioning(sortOrder.map(_.asInstanceOf[SortOrder]), numPartitions) + val directShuffleExprs = partitionExpressions.filter(_.isInstanceOf[DirectShufflePartitionID]) + if (directShuffleExprs.nonEmpty) { + assert(directShuffleExprs.length == 1 && partitionExpressions.length == 1, + s"DirectShufflePartitionID can only be used as a single partition expression, " + + s"but found ${directShuffleExprs.length} DirectShufflePartitionID expressions " + + s"out of ${partitionExpressions.length} total expressions") + ShufflePartitionIdPassThrough( + partitionExpressions.head.asInstanceOf[DirectShufflePartitionID], numPartitions) } else { - HashPartitioning(partitionExpressions, numPartitions) + val (sortOrder, nonSortOrder) = partitionExpressions.partition(_.isInstanceOf[SortOrder]) + require(sortOrder.isEmpty || nonSortOrder.isEmpty, + s"${getClass.getSimpleName} expects that either all its `partitionExpressions` are of" + + " type `SortOrder`,
which means `RangePartitioning`, or none of them are `SortOrder`," + + " which means `HashPartitioning`. In this case we have:" + + s""" + |SortOrder: $sortOrder + |NonSortOrder: $nonSortOrder + """.stripMargin) + if (sortOrder.nonEmpty) { + RangePartitioning(sortOrder.map(_.asInstanceOf[SortOrder]), numPartitions) + } else { + HashPartitioning(partitionExpressions, numPartitions) + } } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 038105f9bfdf..f855483ea3c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -626,6 +626,47 @@ case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning { * - Creating a partitioning that can be used to re-partition another child, so that to make it * having a compatible partitioning as this node. */ + +/** + * Represents a partitioning where partition IDs are passed through directly from the + * DirectShufflePartitionID expression. This partitioning scheme is used when users + * want to directly control partition placement rather than using hash-based partitioning. + * + * IDs are `pmod(expr.child, numPartitions)`; they map to the PartitionIdPassthrough partitioner.
+ */ +case class ShufflePartitionIdPassThrough( + expr: DirectShufflePartitionID, + numPartitions: Int) extends Expression with Partitioning with Unevaluable { + + // TODO(SPARK-53401): Support Shuffle Spec in Direct Partition ID Pass Through + def partitionIdExpression: Expression = Pmod(expr.child, Literal(numPartitions)) + + def expressions: Seq[Expression] = expr :: Nil + override def children: Seq[Expression] = expr :: Nil + override def nullable: Boolean = false + override def dataType: DataType = IntegerType + + override def satisfies0(required: Distribution): Boolean = { + super.satisfies0(required) || { + required match { + // TODO(SPARK-53428): Support Direct Passthrough Partitioning in the Streaming Joins + case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) => + val partitioningExpressions = expr.child :: Nil + if (requireAllClusterKeys) { + c.areAllClusterKeysMatched(partitioningExpressions) + } else { + partitioningExpressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) + } + case _ => false + } + } + } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): ShufflePartitionIdPassThrough = + copy(expr = newChildren.head.asInstanceOf[DirectShufflePartitionID]) +} + trait ShuffleSpec { /** * Returns the number of partitions of this shuffle spec diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala index e2688c7ddab1..c5817e6f5bb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala @@ -1544,6 +1544,21 @@ class Dataset[T] private[sql]( } } + /** + * Repartitions the Dataset into the given number of partitions using the specified + * partition ID expression. + * + * @param numPartitions the number of partitions to use.
+ * @param partitionIdExpr the expression to be used as the partition ID. Must be an integer type. + * + * @group typedrel + * @since 4.1.0 + */ + def repartitionById(numPartitions: Int, partitionIdExpr: Column): Dataset[T] = { + val directShufflePartitionIdCol = Column(DirectShufflePartitionID(partitionIdExpr.expr)) + repartitionByExpression(Some(numPartitions), Seq(directShufflePartitionIdCol)) + } + protected def repartitionByRange( numPartitions: Option[Int], partitionExprs: Seq[Column]): Dataset[T] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 31a3f53eb719..3b8fa821eac7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -344,6 +344,10 @@ object ShuffleExchangeExec { // For HashPartitioning, the partitioning key is already a valid partition ID, as we use // `HashPartitioning.partitionIdExpression` to produce partitioning key. new PartitionIdPassthrough(n) + case ShufflePartitionIdPassThrough(_, n) => + // For ShufflePartitionIdPassThrough, the DirectShufflePartitionID expression directly
+ // produces partition IDs, so we use PartitionIdPassthrough to pass them through directly. + new PartitionIdPassthrough(n) case RangePartitioning(sortingExpressions, numPartitions) => // Extract only fields used for sorting to avoid collecting large fields that does not // affect sorting result when deciding partition bounds in RangePartitioner @@ -399,6 +403,11 @@ object ShuffleExchangeExec { case SinglePartition => identity case KeyGroupedPartitioning(expressions, _, _, _) => row => bindReferences(expressions, outputAttributes).map(_.eval(row)) + case s: ShufflePartitionIdPassThrough => + // For ShufflePartitionIdPassThrough, `partitionIdExpression` evaluates to the partition ID. + // If the value is null, `InternalRow#getInt` returns 0, so the row goes to partition 0. + val projection = UnsafeProjection.create(s.partitionIdExpression :: Nil, outputAttributes) + row => projection(row).getInt(0) case _ => throw SparkException.internalError(s"Exchange not implemented for $newPartitioning") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 0aeff9a2af01..b926cc192bd6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.rdd.RDD import org.apache.spark.sql.{execution, DataFrame, Row} +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ @@ -28,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecution} import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.columnar.{InMemoryRelation,
InMemoryTableScanExec} -import org.apache.spark.sql.execution.exchange.{EnsureRequirements, REPARTITION_BY_COL, ReusedExchangeExec, ShuffleExchangeExec} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, EnsureRequirements, REPARTITION_BY_COL, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery import org.apache.spark.sql.functions._ @@ -1406,6 +1407,183 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper { assert(planned.exists(_.isInstanceOf[GlobalLimitExec])) assert(planned.exists(_.isInstanceOf[LocalLimitExec])) } + + test("SPARK-53401: repartitionById - should partition rows to the specified partition ID") { + val numPartitions = 10 + val df = spark.range(100).withColumn("expected_p_id", col("id") % numPartitions) + + val repartitioned = df.repartitionById(numPartitions, $"expected_p_id".cast("int")) + val result = repartitioned.withColumn("actual_p_id", spark_partition_id()) + + assert(result.filter(col("expected_p_id") =!= col("actual_p_id")).count() == 0) + + assert(result.rdd.getNumPartitions == numPartitions) + } + + test("SPARK-53401: repartitionById should handle negative partition ids correctly with pmod") { + val df = spark.range(10).toDF("id") + val repartitioned = df.repartitionById(10, ($"id" - 5).cast("int")) + + // With pmod, negative values should be converted to positive values + // (-5) pmod 10 = 5, (-4) pmod 10 = 6 + val result = repartitioned.withColumn("actual_p_id", spark_partition_id()).collect() + + assert(result.forall(row => { + val actualPartitionId = row.getAs[Int]("actual_p_id") + val id = row.getAs[Long]("id") + val expectedPartitionId = { + val mod = (id - 5) % 10 + if (mod < 0) mod + 10 else mod + } + actualPartitionId == expectedPartitionId + })) + } + + test("SPARK-53401: repartitionById should fail analysis for
non-integral types") { + val df = spark.range(5).withColumn("s", lit("a")) + checkError( + exception = intercept[AnalysisException] { + df.repartitionById(5, $"s").collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> "\"direct_shuffle_partition_id(s)\"", + "paramIndex" -> "first", + "requiredType" -> "\"INT\"", + "inputType" -> "\"STRING\"", + "inputSql" -> "\"s\"" + ) + ) + } + + test("SPARK-53401: repartitionById should send null partition ids to partition 0") { + val df = spark.range(10).toDF("id") + val partitionExpr = when($"id" < 5, $"id").otherwise(lit(null)).cast("int") + val repartitioned = df.repartitionById(10, partitionExpr) + + val result = repartitioned.withColumn("actual_p_id", spark_partition_id()).collect() + + val nullRows = result.filter(_.getAs[Long]("id") >= 5) + assert(nullRows.nonEmpty, "Should have rows with null partition expression") + assert(nullRows.forall(_.getAs[Int]("actual_p_id") == 0), + "All null partition id rows should go to partition 0") + + val nonNullRows = result.filter(_.getAs[Long]("id") < 5) + nonNullRows.foreach { row => + val id = row.getAs[Long]("id").toInt + val actualPartitionId = row.getAs[Int]("actual_p_id") + assert(actualPartitionId == id % 10, + s"Row with id=$id should be in partition ${id % 10}, " + + s"but was in partition $actualPartitionId") + } + } + + test("SPARK-53401: repartitionById should not" + + " throw an exception for partition id >= numPartitions") { + val numPartitions = 10 + val df = spark.range(20).toDF("id") + val repartitioned = df.repartitionById(numPartitions, $"id".cast("int")) + + assert(repartitioned.collect().length == 20) + assert(repartitioned.rdd.getNumPartitions == numPartitions) + } + + /** + * A helper function to check the number of shuffle and broadcast exchanges in a plan. + * + * @param df The DataFrame whose physical plan will be examined. + * @param expectedShuffles The expected number of shuffle plus broadcast exchanges.
+ */ + private def checkShuffleCount(df: DataFrame, expectedShuffles: Int): Unit = { + val plan = df.queryExecution.executedPlan + val shuffles = collect(plan) { + case s: ShuffleExchangeLike => s + case s: BroadcastExchangeLike => s + } + assert( + shuffles.size == expectedShuffles, + s"Expected $expectedShuffles shuffle(s), but found ${shuffles.size} in the plan:\n$plan" + ) + } + + test("SPARK-53401: repartitionById followed by groupBy should only have one shuffle") { + val df = spark.range(100) + .withColumn("id", col("id").cast("int")) + .toDF("id") + val repartitioned = df.repartitionById(10, $"id") + val grouped = repartitioned.groupBy($"id").count() + + checkShuffleCount(grouped, 1) + } + + test("SPARK-53401: groupBy on a superset of partition keys should reuse the shuffle") { + val df = spark.range(100) + .withColumn("id", col("id").cast("int")) + .select($"id" % 10 as "key1", $"id" as "value") + val grouped = df.repartitionById(10, $"key1").groupBy($"key1", lit(1)).count() + checkShuffleCount(grouped, 1) + } + + test("SPARK-53401: shuffle reuse is not affected by spark.sql.shuffle.partitions") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "5") { + val df = spark.range(100) + .withColumn("id", col("id").cast("int")) + .select($"id" % 10 as "key", $"id" as "value") + val grouped = df.repartitionById(10, $"key").groupBy($"key").count() + + checkShuffleCount(grouped, 1) + assert(grouped.rdd.getNumPartitions == 10) + } + } + + test("SPARK-53401: join with id pass-through and hash partitioning requires shuffle") { + val df1 = spark.range(100) + .withColumn("id", col("id").cast("int")) + .select($"id" % 10 as "key", $"id" as "v1") + .repartitionById(10, $"key") + + val df2 = spark.range(100) + .withColumn("id", col("id").cast("int")) + .select($"id" % 10 as "key", $"id" as "v2") + .repartition($"key") + + val joined1 = df1.join(df2, "key") + + val grouped = joined1.groupBy("key").count() + + // Expected exchanges: the shuffles for df1 and df2, plus one broadcast for the
join. + // The groupBy reuses the join's output partitioning, so it adds no shuffle of its own. + checkShuffleCount(grouped, 3) + + val joined2 = df2.join(df1, "key") + + val grouped2 = joined2.groupBy("key").count() + + checkShuffleCount(grouped2, 3) + } + + test("SPARK-53401: shuffle reuse after a join doesn't preserve partitioning") { + val df1 = + spark + .range(100) + .withColumn("id", col("id").cast("int")) + .select($"id" % 10 as "key", $"id" as "v1") + .repartitionById(10, $"key") + val df2 = + spark + .range(100) + .withColumn("id", col("id").cast("int")) + .select($"id" % 10 as "key", $"id" as "v2") + .repartitionById(10, $"key") + + val joined = df1.join(df2, "key") + + val grouped = joined.groupBy("key").count() + + // Expected exchanges: the shuffles for df1 and df2, plus one exchange for the join. + // The groupBy reuses the join's output partitioning, so it adds no shuffle of its own. + checkShuffleCount(grouped, 3) + } } // Used for unit-testing EnsureRequirements --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org