This is an automated email from the ASF dual-hosted git repository.
peter-toth pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 408a11519641 [SPARK-57038][SQL] Use `Expression.references` in SPJ
planning
408a11519641 is described below
commit 408a11519641e0bc07b847b8dd1117cab8891ae0
Author: Peter Toth <[email protected]>
AuthorDate: Tue May 26 10:22:42 2026 +0200
[SPARK-57038][SQL] Use `Expression.references` in SPJ planning
### What changes were proposed in this pull request?
- Document `AttributeSet`'s iteration-order contract on the class scaladoc:
iteration via `iterator` / `foreach` / `flatMap` returns elements in insertion
order (driven by the underlying `LinkedHashSet`). `toSeq` is called out as the
explicit exception — it sorts by `(name, exprId.id)` for codegen stability
(SPARK-18394).
- Migrate seven SPJ-related uses of `_.collectLeaves()` in
`partitioning.scala` and `EnsureRequirements.scala` to `_.references` /
`AttributeSet.fromAttributeSets(...)`. Drops the now-redundant
`.map(_.asInstanceOf[Attribute])` cast at the `EnsureRequirements:89` site.
- Update `EnsureRequirementsSuite` synthetic fixtures (`exprA..D`) from
`Literal(1..4)` to `AttributeReference`s. The literals were stand-ins for
columns; under the migration's `_.references`-based attribute extraction,
literal children produce empty `AttributeSet`s and trip the planner's "exactly
one attribute per partition expression" assertions. Real partitionings can't
reach those assertions with literal-only transforms because
`KeyedPartitioning.supportsExpressions`'s `isReferenc [...]
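
As a rough illustration of the iteration-order contract in the first bullet, here is a minimal sketch built on toy types. `ToyAttr`, `ToyAttrSet`, and `IterationOrderDemo` are hypothetical stand-ins, not Spark's `Attribute` or `AttributeSet`. The sketch assumes only what the bullet states: a `LinkedHashSet` backing store and a `toSeq` that sorts by `(name, id)`.

```scala
import scala.collection.mutable

// Hypothetical stand-in for an attribute: a name plus a unique id.
final case class ToyAttr(name: String, id: Long)

// Hypothetical stand-in for AttributeSet, backed by a LinkedHashSet.
final class ToyAttrSet(attrs: ToyAttr*) {
  private val base = mutable.LinkedHashSet(attrs: _*)

  // Visits elements in insertion order; a duplicate keeps its first position.
  def iterator: Iterator[ToyAttr] = base.iterator

  // The explicit exception: sorted by (name, id) rather than insertion order.
  def toSeq: Seq[ToyAttr] = base.toSeq.sortBy(a => (a.name, a.id))
}

object IterationOrderDemo {
  val b = ToyAttr("b", 1L)
  val a = ToyAttr("a", 2L)
  val set = new ToyAttrSet(b, a, b) // the second `b` is a duplicate

  val iterated: List[String] = set.iterator.map(_.name).toList
  val sorted: List[String] = set.toSeq.map(_.name).toList

  def main(args: Array[String]): Unit = {
    println(iterated) // List(b, a)
    println(sorted)   // List(a, b)
  }
}
```

In this toy model, code that binds by position (for example, zipping attributes with clustering keys) would have to use `iterator`, because `toSeq` reorders the elements.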
### Why are the changes needed?
`TreeNode.collectLeaves()` returns every node in the tree where
`children.isEmpty`, including `Literal`s. SPJ planning has always wanted
attributes only, but with the existing partition expression layout
(`TransformExpression.children = [col]`, parameters carried in a sidecar
`numBucketsOpt: Option[Int]` field), the difference didn't surface.
Follow-up work (e.g. SPARK-50593 / #55885) that puts literal parameters
directly into `TransformExpression.children` (`bucket(Literal(numBuckets),
col)`, `truncate(col, Literal(width))`) would otherwise force
`TransformExpression` to override `collectLeaves` to filter literals, breaking
the universal `TreeNode.collectLeaves` contract for one expression type.
`Expression.references` already returns attributes only (filtering literals
and other non-attribute leaves), and its insertion-ordered iteration is exactly
what positional binding (`RowOrdering.create`, `reorder`,
`attributes.zip(clustering)`) requires. Because
`KeyedPartitioning.supportsExpressions` enforces a single column per partition
expression, de-duplication within one expression never changes the result. Documenting the
iteration-order contract lets these call sites rely on the order [...]
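
The difference between the two traversals can be sketched with a toy expression tree. Every name below (`ToyExpr`, `ToyColumn`, `ToyLiteral`, `ToyTransform`, `LeavesVsReferencesDemo`) is hypothetical and only mimics the behaviour described above; none of it is Catalyst code. `collectLeaves` returns every childless node, while `references` keeps attributes only.

```scala
// Hypothetical, simplified model of an expression tree.
sealed trait ToyExpr {
  def children: Seq[ToyExpr]

  // Every node without children, literals included.
  def collectLeaves: Seq[ToyExpr] =
    if (children.isEmpty) Seq(this) else children.flatMap(_.collectLeaves)

  // Column leaves only, de-duplicated, in first-seen order.
  def references: Seq[ToyColumn] =
    collectLeaves.collect { case c: ToyColumn => c }.distinct
}

final case class ToyColumn(name: String) extends ToyExpr {
  def children: Seq[ToyExpr] = Nil
}

final case class ToyLiteral(value: Int) extends ToyExpr {
  def children: Seq[ToyExpr] = Nil
}

final case class ToyTransform(name: String, children: Seq[ToyExpr]) extends ToyExpr

object LeavesVsReferencesDemo {
  // Existing layout: the bucket count lives outside `children`.
  val sidecar = ToyTransform("bucket", Seq(ToyColumn("id")))

  // Follow-up layout: the bucket count is a literal child.
  val inlined = ToyTransform("bucket", Seq(ToyLiteral(16), ToyColumn("id")))

  def main(args: Array[String]): Unit = {
    println(sidecar.collectLeaves.size) // 1
    println(inlined.collectLeaves.size) // 2: the literal is counted as a leaf
    println(inlined.references)         // List(ToyColumn(id))
  }
}
```

Under this model, a "one leaf per expression" check holds for the sidecar layout but fails once the literal becomes a child, whereas a "one reference per expression" check holds for both.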
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Covered by existing test suites that exercise the migrated call sites:
`KeyGroupedPartitioningSuite`, `EnsureRequirementsSuite`,
`ProjectedOrderingAndPartitioningSuite`.
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Claude Opus 4.7
Closes #56088 from peter-toth/SPARK-57038-revisit-collectleaves.
Authored-by: Peter Toth <[email protected]>
Signed-off-by: Peter Toth <[email protected]>
---
.../spark/sql/catalyst/expressions/AttributeSet.scala | 4 ++++
.../sql/catalyst/plans/physical/partitioning.scala | 10 +++++-----
.../sql/execution/exchange/EnsureRequirements.scala | 17 ++++++++++++-----
.../execution/exchange/EnsureRequirementsSuite.scala | 8 ++++----
4 files changed, 25 insertions(+), 14 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
index 236380b2c030..d958cba27933 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
@@ -65,6 +65,10 @@ object AttributeSet {
* `AttributeReference("a"...) == AttributeReference("b", ...)`. This tactic
leads to broken tests,
* and also makes doing transformations hard (we always try keep older trees
instead of new ones
* when the transformation was a no-op).
+ *
+ * Iteration via [[iterator]], [[foreach]], or [[Iterable]]-derived
combinators (`flatMap`, etc.)
+ * visits elements in insertion order. Note: [[toSeq]] is an explicit
exception -- it sorts by
+ * `(name, exprId.id)` for stable codegen output, see SPARK-18394.
*/
class AttributeSet private (private val baseSet:
mutable.LinkedHashSet[AttributeEquals])
extends Iterable[Attribute] with Serializable {
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index e92bb0f7c0d6..aeacdaec7a8d 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -578,13 +578,13 @@ case class KeyedPartitioning(
c.areAllClusterKeysMatched(expressions)
} else {
// We'll need to find leaf attributes from the partition expressions
first.
- lazy val attributes = expressions.flatMap(_.collectLeaves())
+ lazy val attributes =
AttributeSet.fromAttributeSets(expressions.map(_.references))
if (SQLConf.get.v2BucketingAllowKeysSubsetOfPartitionKeys) {
// check that operation keys (required clustering keys)
// overlap with partition keys (KeyedPartitioning attributes)
requiredClustering.exists(x =>
attributes.exists(_.semanticEquals(x))) &&
- expressions.forall(_.collectLeaves().size == 1)
+ expressions.forall(_.references.size == 1)
} else if (isNarrowed && !isGrouped) {
// A narrowed, non-grouped partitioning carries the same skew risk
as using a subset of
// partition keys for a join: GroupPartitionsExec will merge
partitions that held
@@ -1218,9 +1218,9 @@ case class KeyedShuffleSpec(
distKeyToPos.getOrElseUpdate(distKey.canonicalized,
mutable.BitSet.empty).add(distKeyPos)
}
partitioning.expressions.map { e =>
- val leaves = e.collectLeaves()
- assert(leaves.size == 1, s"Expected exactly one child from $e, but found
${leaves.size}")
- distKeyToPos.getOrElse(leaves.head.canonicalized, mutable.BitSet.empty)
+ val refs = e.references
+ assert(refs.size == 1, s"Expected exactly one child from $e, but found
${refs.size}")
+ distKeyToPos.getOrElse(refs.head.canonicalized, mutable.BitSet.empty)
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index c5a08e983e61..c632b3d841e6 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -86,8 +86,9 @@ case class EnsureRequirements(
// Find any KeyedPartitioning that satisfies via
groupedSatisfies.
val satisfyingKeyedPartitioning =
groupedSatisfies.orElse(nonGroupedSatisfiesWhenGrouped).get
- val attrs =
satisfyingKeyedPartitioning.expressions.flatMap(_.collectLeaves())
- .map(_.asInstanceOf[Attribute])
+ // The single-column invariant in
KeyedPartitioning.supportsExpressions guarantees
+ // one attribute per partition expression.
+ val attrs =
satisfyingKeyedPartitioning.expressions.flatMap(_.references)
val keyRowOrdering = RowOrdering.create(o.ordering, attrs)
val keyOrdering = keyRowOrdering.on((t:
InternalRowComparableWrapper) => t.row)
if
(satisfyingKeyedPartitioning.partitionKeys.sliding(2).forall {
@@ -409,12 +410,16 @@ case class EnsureRequirements(
.orElse(reorderJoinKeysRecursively(
leftKeys, rightKeys, leftPartitioning, None))
case (Some(KeyedPartitioning(clustering, _, _, _)), _) =>
- val leafExprs = clustering.flatMap(_.collectLeaves())
+ // The single-column invariant in
KeyedPartitioning.supportsExpressions guarantees one
+ // attribute per partition expression.
+ val leafExprs = clustering.flatMap(_.references)
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs,
leftKeys)
.orElse(reorderJoinKeysRecursively(
leftKeys, rightKeys, None, rightPartitioning))
case (_, Some(KeyedPartitioning(clustering, _, _, _))) =>
- val leafExprs = clustering.flatMap(_.collectLeaves())
+ // The single-column invariant in
KeyedPartitioning.supportsExpressions guarantees one
+ // attribute per partition expression.
+ val leafExprs = clustering.flatMap(_.references)
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs,
rightKeys)
.orElse(reorderJoinKeysRecursively(
leftKeys, rightKeys, leftPartitioning, None))
@@ -777,7 +782,9 @@ case class EnsureRequirements(
partitioning: Partitioning,
distribution: ClusteredDistribution): Option[KeyedShuffleSpec] = {
def tryCreate(partitioning: KeyedPartitioning): Option[KeyedShuffleSpec] =
{
- val attributes = partitioning.expressions.flatMap(_.collectLeaves())
+ // The single-column invariant in KeyedPartitioning.supportsExpressions
guarantees one
+ // attribute per partition expression.
+ val attributes = partitioning.expressions.flatMap(_.references)
val clustering = distribution.clustering
val satisfies = if
(SQLConf.get.getConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION)) {
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
index 74b706bce34f..17d00ec055e0 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
@@ -40,10 +40,10 @@ import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
class EnsureRequirementsSuite extends SharedSparkSession {
- private val exprA = Literal(1)
- private val exprB = Literal(2)
- private val exprC = Literal(3)
- private val exprD = Literal(4)
+ private val exprA = AttributeReference("a", IntegerType)()
+ private val exprB = AttributeReference("b", IntegerType)()
+ private val exprC = AttributeReference("c", IntegerType)()
+ private val exprD = AttributeReference("d", IntegerType)()
private val EnsureRequirements = new EnsureRequirements()