This is an automated email from the ASF dual-hosted git repository.
peter-toth pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new c8528a73d4b7 [SPARK-56877][SQL] Enforce `KeyedPartitioning` invariant in `PartitioningCollection`
c8528a73d4b7 is described below
commit c8528a73d4b7a205e44b1530e6423243be963af4
Author: Peter Toth <[email protected]>
AuthorDate: Mon May 18 11:08:46 2026 +0200
[SPARK-56877][SQL] Enforce `KeyedPartitioning` invariant in `PartitioningCollection`
### What changes were proposed in this pull request?
- Add a `require` in `PartitioningCollection` that all `KeyedPartitioning`s
reachable from the collection share the same `partitionKeys` reference (`eq`)
and have matching expression arity. The check walks the partitioning tree via
`foreach` so nested collections are covered.
- Add a smart factory `PartitioningCollection.fromPartitionings` that
interns `partitionKeys` references across `KeyedPartitioning`s. Use this at
sites that combine independently-computed partitionings (joins) where keys are
structurally equal but not reference-equal. The factory uses manual recursion
rather than `transformWithPruning` because `KeyedPartitioning.equals` compares
`partitionKeys` element-wise, which would make `transformWithPruning` discard
the rule's replacement as str [...]
- In `GroupPartitionsExec.outputPartitioning`, hoist `val partitionKeys =
groupedPartitions.map(_._1)` above the `transform` so every rebuilt
`KeyedPartitioning` shares the same `partitionKeys` reference. Drop the ad-hoc
consistency assert (now enforced by `PartitioningCollection`).
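- Switch `ShuffledJoin` and `StreamingSymmetricHashJoinExec` (for their
inner-join `outputPartitioning`), as well as `BroadcastHashJoinExec` and
`PartitioningPreservingUnaryExecNode` in `AliasAwareOutputExpression`, to
`PartitioningCollection.fromPartitionings`. Drop the now-redundant arity assert
in `AliasAwareOutputExpression.projectKeyedPartitionings`.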
- Update affected tests to construct collections via `fromPartitionings`.
Rewrite the `SPARK-46367` arity-mismatch test in
`ProjectedOrderingAndPartitioningSuite` since the scenario is now rejected at
`PartitioningCollection` construction rather than inside
`AliasAwareOutputExpression`.
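To illustrate the note above on why `fromPartitionings` recurses by hand, here is a
hypothetical miniature in plain Scala (not Spark's `TreeNode` API). `Keyed`,
`transformKeepingEqual`, `transformManually`, and `intern` are illustrative names;
`transformKeepingEqual` only assumes the behavior described above, namely that a rule
result comparing equal to the original node is discarded in favor of the original.

```scala
object TransformPitfall {
  // Hypothetical stand-in for KeyedPartitioning: case-class equality compares
  // `keys` element-wise, so equal-but-distinct key sequences make nodes equal.
  final case class Keyed(exprs: Seq[String], keys: Seq[Int])

  // Hypothetical rewrite that keeps the original node whenever the rule's
  // result is `eq` or `==` to it (the behavior assumed above).
  def transformKeepingEqual(node: Keyed)(rule: PartialFunction[Keyed, Keyed]): Keyed = {
    val after = rule.applyOrElse(node, (k: Keyed) => k)
    if ((node eq after) || node == after) node else after
  }

  // Manual rewrite: always returns whatever the rule produced.
  def transformManually(node: Keyed)(rule: PartialFunction[Keyed, Keyed]): Keyed =
    rule.applyOrElse(node, (k: Keyed) => k)

  // The reference every node should end up sharing.
  val canonical: Seq[Int] = Vector(1, 2, 3)
  // Structurally equal to `canonical`, but a different object.
  val independent: Seq[Int] = List(1, 2, 3)

  // Rule that swaps in the canonical keys reference.
  val intern: PartialFunction[Keyed, Keyed] = { case k => k.copy(keys = canonical) }

  def main(args: Array[String]): Unit = {
    val node = Keyed(Seq("a"), independent)
    // false: the interned copy equals `node`, so the rewrite is dropped
    println(transformKeepingEqual(node)(intern).keys eq canonical)
    // true: the manual rewrite keeps the new reference
    println(transformManually(node)(intern).keys eq canonical)
  }
}
```

Swapping in the canonical `keys` reference changes identity but not equality, so an
equality-based short-circuit cannot tell that the rule did anything; a manual walk has
no such short-circuit, which is the property the factory relies on.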
### Why are the changes needed?
The "all `KeyedPartitioning`s in a collection must agree on
`partitionKeys`" invariant already existed informally:
`GroupPartitionsExec.outputPartitioning` had a runtime assert checking `==`,
`AliasAwareOutputExpression.projectKeyedPartitionings` asserted matching arity,
and various consumers relied on the invariant being upheld. Consolidating the
check into the `PartitioningCollection` constructor makes it load-bearing: any
future construction site that violates it fails immediatel [...]
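A minimal sketch of how the constructor check and the interning factory fit together,
using hypothetical stand-in types: `Part`, `Keyed`, `Other`, `Coll`, `foreachKeyed`, and
`fromParts` only mirror the roles of `Partitioning`, `KeyedPartitioning`, a non-keyed
partitioning, `PartitioningCollection`, `TreeNode.foreach`, and `fromPartitionings` in the
diff below, and keys are plain `Seq[Int]` values rather than
`InternalRowComparableWrapper`s. It shows the fail-fast behavior described above:
equal-but-distinct keys are rejected on direct construction and accepted once interned.

```scala
import scala.util.Try

object KeyedInvariantSketch {
  // Hypothetical stand-ins for Partitioning, KeyedPartitioning, a non-keyed
  // partitioning, and PartitioningCollection.
  sealed trait Part
  final case class Keyed(exprs: Seq[String], keys: Seq[Int]) extends Part
  final case class Other(name: String) extends Part
  final case class Coll(parts: Seq[Part]) extends Part {
    // Like the constructor `require` in the diff: runs on every construction.
    checkInvariant()

    private def checkInvariant(): Unit = {
      var first: Keyed = null
      foreachKeyed(this) { k =>
        if (first == null) {
          first = k
        } else {
          require(k.exprs.length == first.exprs.length, "matching expression arity")
          require(k.keys eq first.keys, "shared keys reference")
        }
      }
    }
  }

  // Visits every Keyed reachable from `p`, descending into nested collections.
  def foreachKeyed(p: Part)(f: Keyed => Unit): Unit = p match {
    case k: Keyed => f(k)
    case Coll(ps) => ps.foreach(child => foreachKeyed(child)(f))
    case _: Other => ()
  }

  // Mirrors fromPartitionings: the first Keyed found fixes the canonical keys;
  // later ones must be `==` to it and are copied to share that reference.
  def fromParts(parts: Seq[Part]): Coll = {
    var canonical: Seq[Int] = null
    def intern(p: Part): Part = p match {
      case k: Keyed if canonical == null =>
        canonical = k.keys
        k
      case k: Keyed if k.keys ne canonical =>
        require(k.keys == canonical, "equal keys")
        k.copy(keys = canonical)
      case k: Keyed => k
      case Coll(ps) => Coll(ps.map(intern))
      case other => other
    }
    Coll(parts.map(intern))
  }

  def main(args: Array[String]): Unit = {
    val keysA: Seq[Int] = List(1, 2)
    val keysB: Seq[Int] = List(1, 2) // equal to keysA, but a different object
    val left = Keyed(Seq("l.k"), keysA)
    val right = Keyed(Seq("r.k"), keysB)
    // Direct construction fails fast: the keys are equal but not shared.
    println(Try(Coll(Seq(left, right))).isFailure) // true
    // The factory interns the keys, including inside a nested collection.
    val c = fromParts(Seq(left, Coll(Seq(right, Other("hash")))))
    var allShared = true
    foreachKeyed(c)(k => allShared = allShared && (k.keys eq keysA))
    println(allShared) // true
  }
}
```

As in the diff, the constructor only checks reference equality (`eq`), while the factory
pays for an element-wise `==` once per non-identical reference. Since the real arity
message does not mention `partitionKeys`, the rewritten `SPARK-46367` test (which asserts
on `partitionKeys`) appears to be exercising the factory's key-equality `require`
(`keys2d` vs `keys1d`) rather than the constructor's arity check.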
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing test suites (`EnsureRequirementsSuite`,
`GroupPartitionsExecSuite`, `ProjectedOrderingAndPartitioningSuite`) updated to
use `PartitioningCollection.fromPartitionings` where they previously
constructed collections from independently-built `KeyedPartitioning`s. The
`SPARK-46367` test was rewritten to assert that the invalid mixed-arity
scenario is rejected at `PartitioningCollection` construction.
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Claude Code 4.7
Closes #55901 from peter-toth/SPARK-56877-enforce-keyedpartitioning-invariant-in-collection.
Authored-by: Peter Toth <[email protected]>
Signed-off-by: Peter Toth <[email protected]>
---
.../sql/catalyst/plans/physical/partitioning.scala | 56 ++++++++++++++++++++++
.../sql/execution/AliasAwareOutputExpression.scala | 17 ++-----
.../datasources/v2/GroupPartitionsExec.scala | 20 ++------
.../execution/joins/BroadcastHashJoinExec.scala | 2 +-
.../spark/sql/execution/joins/ShuffledJoin.scala | 3 +-
.../join/StreamingSymmetricHashJoinExec.scala | 3 +-
.../ProjectedOrderingAndPartitioningSuite.scala | 25 ++++------
.../datasources/v2/GroupPartitionsExecSuite.scala | 2 +-
.../exchange/EnsureRequirementsSuite.scala | 12 ++---
9 files changed, 87 insertions(+), 53 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index cc50da1f17fd..f331cd124759 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -691,6 +691,12 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
* `HashPartitioning(B.key2)`. It is also worth noting that `partitionings`
* in this collection do not need to be equivalent, which is useful for
* Outer Join operators.
+ *
+ * [[KeyedPartitioning]]s within a `PartitioningCollection` describe the same physical partitioning
+ * and so must share the same `partitionKeys` reference, differing only in their `expressions` (with
+ * matching arity). Use [[PartitioningCollection.fromPartitionings]] to build a collection from
+ * independently-computed partitionings (e.g. join `outputPartitioning`); it interns `partitionKeys`
+ * references (including across nested collections) so the invariant holds.
*/
case class PartitioningCollection(partitionings: Seq[Partitioning])
extends Expression with Partitioning with Unevaluable {
@@ -699,6 +705,26 @@ case class PartitioningCollection(partitionings: Seq[Partitioning])
partitionings.map(_.numPartitions).distinct.length == 1,
s"PartitioningCollection requires all of its partitionings have the same numPartitions.")
+ checkKeyedPartitioningInvariant()
+
+ private def checkKeyedPartitioningInvariant(): Unit = {
+ var first: KeyedPartitioning = null
+ foreach {
+ case k: KeyedPartitioning =>
+ if (first == null) {
+ first = k
+ } else {
+ require(k.expressions.length == first.expressions.length,
+ "All KeyedPartitionings in a PartitioningCollection must have matching expression " +
+ "arity")
+ require(k.partitionKeys eq first.partitionKeys,
+ "All KeyedPartitionings in a PartitioningCollection must share the same " +
+ "partitionKeys reference")
+ }
+ case _ =>
+ }
+ }
+
override def children: Seq[Expression] = partitionings.collect {
case expr: Expression => expr
}
@@ -730,6 +756,36 @@ case class PartitioningCollection(partitionings: Seq[Partitioning])
super.legacyWithNewChildren(newChildren).asInstanceOf[PartitioningCollection]
}
+object PartitioningCollection {
+ /**
+ * Builds a [[PartitioningCollection]], unifying the `partitionKeys` reference across all
+ * [[KeyedPartitioning]]s (including those in nested collections). Use this when combining
+ * independently-computed partitionings (e.g. join `outputPartitioning`) where
+ * `KeyedPartitioning.partitionKeys` are structurally equal but may not be reference-equal.
+ *
+ * Note: this can't be implemented with `TreeNode.transform`.
+ */
+ def fromPartitionings(partitionings: Seq[Partitioning]): PartitioningCollection = {
+ var canonicalKeys: Seq[InternalRowComparableWrapper] = null
+ def intern(p: Partitioning): Partitioning = p match {
+ case k: KeyedPartitioning =>
+ if (canonicalKeys == null) {
+ canonicalKeys = k.partitionKeys
+ k
+ } else if (k.partitionKeys ne canonicalKeys) {
+ require(k.partitionKeys == canonicalKeys,
+ "All KeyedPartitionings in a PartitioningCollection must have equal partitionKeys")
+ k.copy(partitionKeys = canonicalKeys)
+ } else {
+ k
+ }
+ case pc: PartitioningCollection => new PartitioningCollection(pc.partitionings.map(intern))
+ case other => other
+ }
+ new PartitioningCollection(partitionings.map(intern))
+ }
+}
+
/**
* Represents a partitioning where rows are collected, transformed and broadcasted to each
* node in the cluster.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala
index 1f2b1d0a585d..b37e1b258e9b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala
@@ -40,7 +40,7 @@ trait PartitioningPreservingUnaryExecNode extends UnaryExecNode
(projectedKPs ++ projectedOthers).take(aliasCandidateLimit) match {
case Seq() => UnknownPartitioning(child.outputPartitioning.numPartitions)
case Seq(p) => p
- case ps => PartitioningCollection(ps)
+ case ps => PartitioningCollection.fromPartitionings(ps)
}
}
@@ -88,22 +88,15 @@ trait PartitioningPreservingUnaryExecNode extends UnaryExecNode
*
* The resulting [[KeyedPartitioning]]s are the cross-product of the per-position alternatives
* restricted to the projectable positions. All share the same `partitionKeys` object (projected
- * to the same subset of positions), preserving the invariant required by [[GroupPartitionsExec]].
+ * to the same subset of positions), preserving the invariant required by
+ * [[PartitioningCollection]].
*/
private def projectKeyedPartitionings(
kps: Seq[KeyedPartitioning]): LazyList[KeyedPartitioning] = {
if (kps.isEmpty) return LazyList.empty
+ // All input KPs share the same `partitionKeys` reference and matching arity by the
+ // [[PartitioningCollection]] invariant (the only producer of multi-KP inputs here).
val numPositions = kps.head.expressions.length
- // The function assumes all input KPs share the same `partitionKeys`, which implies matching
- // expression arity. This invariant is asserted by [[GroupPartitionsExec]] and is established
- // by the constructors of [[PartitioningCollection]] feeding this method (a join's
- // `PartitioningCollection(left.outputPartitioning, right.outputPartitioning)` combines KPs
- // that have been aligned by [[EnsureRequirements]] to the same join keys). If the invariant
- // is ever violated upstream, fail early with a clear message instead of throwing an opaque
- // `IndexOutOfBoundsException` from `kp.expressions(i)` below.
- assert(kps.tail.forall(_.expressions.length == numPositions),
- s"All input KeyedPartitionings must share the same expression arity, " +
- s"but got: ${kps.map(_.expressions.length).mkString(", ")}.")
val alternativesPerPosition: IndexedSeq[LazyList[Expression]] =
if (hasAlias) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala
index 264a0e954936..4d87be662293 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala
@@ -67,24 +67,14 @@ case class GroupPartitionsExec(
override def outputPartitioning: Partitioning = {
child.outputPartitioning match {
case p: Partitioning with Expression =>
- // There can be multiple `KeyedPartitioning` in an output partitioning of a join, but they
- // can only differ in `expressions`. `partitionKeys` must match so we can calculate it only
- // once via `groupedPartitions`.
-
- val keyedPartitionings = p.collect { case k: KeyedPartitioning => k }
- if (keyedPartitionings.size > 1) {
- val first = keyedPartitionings.head
- keyedPartitionings.tail.foreach { k =>
- assert(k.partitionKeys == first.partitionKeys,
- "All KeyedPartitioning nodes must have identical partition keys")
- }
- }
-
+ // There can be multiple `KeyedPartitioning`s in an output partitioning of a join, but they
+ // can only differ in `expressions`; their `partitionKeys` reference is shared (enforced by
+ // `PartitioningCollection`), so `groupedPartitions` is computed only once.
+ val partitionKeys = groupedPartitions.map(_._1)
p.transform {
case k: KeyedPartitioning =>
val projectedExpressions = joinKeyPositions.fold(k.expressions)(_.map(k.expressions))
- KeyedPartitioning(projectedExpressions, groupedPartitions.map(_._1),
- isGrouped = isGrouped)
+ KeyedPartitioning(projectedExpressions, partitionKeys, isGrouped = isGrouped)
}.asInstanceOf[Partitioning]
case o => o
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index e4f18c9144dd..2881aeac55d8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -84,7 +84,7 @@ case class BroadcastHashJoinExec private(
// constructor prevents that.
case p :: Nil => p
- case ps => PartitioningCollection(ps)
+ case ps => PartitioningCollection.fromPartitionings(ps)
}
case _ => streamedPlan.outputPartitioning
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
index f363156c81e5..3fb968bfea7a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
@@ -46,7 +46,8 @@ trait ShuffledJoin extends JoinCodegenSupport {
override def outputPartitioning: Partitioning = joinType match {
case _: InnerLike =>
- PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
+ PartitioningCollection.fromPartitionings(
+ Seq(left.outputPartitioning, right.outputPartitioning))
case LeftOuter | LeftSingle => left.outputPartitioning
case RightOuter => right.outputPartitioning
case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala
index 71a7d4cf56e1..9eca04c98591 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala
@@ -242,7 +242,8 @@ case class StreamingSymmetricHashJoinExec(
override def outputPartitioning: Partitioning = joinType match {
case _: InnerLike =>
- PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
+ PartitioningCollection.fromPartitionings(
+ Seq(left.outputPartitioning, right.outputPartitioning))
case LeftOuter => left.outputPartitioning
case RightOuter => right.outputPartitioning
case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala
index a38570924620..a70baece7784 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala
@@ -387,7 +387,7 @@ class ProjectedOrderingAndPartitioningSuite
val y = AttributeReference("y", IntegerType)()
val yAlias = AttributeReference("y_alias", IntegerType)()
val keys2d = Seq(InternalRow(1, 1), InternalRow(1, 2), InternalRow(2, 1), InternalRow(2, 2))
- val childPartitioning = PartitioningCollection(Seq(
+ val childPartitioning = PartitioningCollection.fromPartitionings(Seq(
KeyedPartitioning(Seq(x, y), keys2d),
KeyedPartitioning(Seq(x, yAlias), keys2d)))
val child = DummyLeafExecWithPartitioning(
@@ -587,27 +587,20 @@ class ProjectedOrderingAndPartitioningSuite
}
}
- test("SPARK-46367: mixed-arity KeyedPartitionings in input fail with a clear assertion") {
- // The function assumes all input KPs share the same arity (the invariant asserted by
- // `GroupPartitionsExec`). Without the assert below, indexing `kp.expressions(i)` for
- // `i >= kp.expressions.length` would throw an opaque `IndexOutOfBoundsException`. The assert
- // surfaces the real cause -- an upstream node violated the invariant -- so the bug can be
- // fixed at the producer.
+ test("SPARK-46367: mixed-arity KeyedPartitionings rejected by PartitioningCollection") {
+ // PartitioningCollection enforces matching expression arity (and shared partitionKeys
+ // references) across all its KeyedPartitionings, so the invariant required by
+ // `AliasAwareOutputExpression` cannot be violated by the input.
val x = AttributeReference("x", IntegerType)()
val y = AttributeReference("y", IntegerType)()
val keys2d = Seq(InternalRow(1, 1), InternalRow(2, 2))
val keys1d = Seq(InternalRow(1), InternalRow(2))
- val child = DummyLeafExecWithPartitioning(
- output = Seq(x, y),
- partitioning = PartitioningCollection(Seq(
+ val e = intercept[IllegalArgumentException] {
+ PartitioningCollection.fromPartitionings(Seq(
KeyedPartitioning(Seq(x, y), keys2d),
- KeyedPartitioning(Seq(x), keys1d))))
- val project = ProjectExec(Seq(x), child)
- val e = intercept[AssertionError] {
- project.outputPartitioning
+ KeyedPartitioning(Seq(x), keys1d)))
}
- assert(e.getMessage.contains("All input KeyedPartitionings must share the same expression " +
- "arity"))
+ assert(e.getMessage.contains("partitionKeys"))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExecSuite.scala
index 5d2adeb0c00a..51951d68cc60 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExecSuite.scala
@@ -97,7 +97,7 @@ class GroupPartitionsExecSuite extends SharedSparkSession {
val leftKP = KeyedPartitioning(Seq(exprA), partitionKeys)
val rightKP = KeyedPartitioning(Seq(exprB), partitionKeys)
val child = DummySparkPlan(
- outputPartitioning = PartitioningCollection(Seq(leftKP, rightKP)),
+ outputPartitioning = PartitioningCollection.fromPartitionings(Seq(leftKP, rightKP)),
outputOrdering = Seq(SortOrder(exprA, Ascending, sameOrderExpressions =
Seq(exprB))))
val gpe = GroupPartitionsExec(child)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
index 1e35985f5049..74b706bce34f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
@@ -821,7 +821,7 @@ class EnsureRequirementsSuite extends SharedSparkSession {
KeyedPartitioning(bucket(4, exprA) :: bucket(16, exprB) :: Nil, Seq.empty)
)
plan2 = new DummySparkPlanWithBatchScanChild(
- outputPartitioning = PartitioningCollection(Seq(
+ outputPartitioning = PartitioningCollection.fromPartitionings(Seq(
KeyedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, Seq.empty),
KeyedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, Seq.empty))
)
@@ -1050,7 +1050,7 @@ class EnsureRequirementsSuite extends SharedSparkSession {
// With partition collections
plan1 = new DummySparkPlanWithBatchScanChild(outputPartitioning =
- PartitioningCollection(
+ PartitioningCollection.fromPartitionings(
Seq(KeyedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, leftPartValues),
KeyedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, leftPartValues))
)
@@ -1077,13 +1077,13 @@ class EnsureRequirementsSuite extends SharedSparkSession {
// Nested partition collections
plan2 = new DummySparkPlanWithBatchScanChild(outputPartitioning =
- PartitioningCollection(
+ PartitioningCollection.fromPartitionings(
Seq(
- PartitioningCollection(
+ PartitioningCollection.fromPartitionings(
Seq(
KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, rightPartValues),
KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, rightPartValues))),
- PartitioningCollection(
+ PartitioningCollection.fromPartitionings(
Seq(
KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, rightPartValues),
KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, rightPartValues)))
@@ -1539,7 +1539,7 @@ private case class DummyBothKPBinaryExec(left: SparkPlan, right: SparkPlan)
override def output: Seq[Attribute] = left.output ++ right.output
override def outputOrdering: Seq[SortOrder] = left.outputOrdering
override def outputPartitioning: Partitioning =
- PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
+ PartitioningCollection.fromPartitionings(Seq(left.outputPartitioning, right.outputPartitioning))
override protected def doExecute(): RDD[InternalRow] = null
override protected def withNewChildrenInternal(
newLeft: SparkPlan, newRight: SparkPlan): SparkPlan =