This is an automated email from the ASF dual-hosted git repository.

peter-toth pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 408a11519641 [SPARK-57038][SQL] Use `Expression.references` in SPJ planning
408a11519641 is described below

commit 408a11519641e0bc07b847b8dd1117cab8891ae0
Author: Peter Toth <[email protected]>
AuthorDate: Tue May 26 10:22:42 2026 +0200

    [SPARK-57038][SQL] Use `Expression.references` in SPJ planning
    
    ### What changes were proposed in this pull request?
    
    - Document `AttributeSet`'s iteration-order contract on the class scaladoc:
      iteration via `iterator` / `foreach` / `flatMap` returns elements in
      insertion order (driven by the underlying `LinkedHashSet`). `toSeq` is
      called out as the explicit exception: it sorts by `(name, exprId.id)` for
      codegen stability (SPARK-18394).
    - Migrate seven SPJ-related uses of `_.collectLeaves()` in
      `partitioning.scala` and `EnsureRequirements.scala` to `_.references` /
      `AttributeSet.fromAttributeSets(...)`. Drops the now-redundant
      `.map(_.asInstanceOf[Attribute])` cast at the `EnsureRequirements:89` site.
    - Update `EnsureRequirementsSuite` synthetic fixtures (`exprA..D`) from
      `Literal(1..4)` to `AttributeReference`s. The literals were stand-ins for
      columns; under the migration's `_.references`-based attribute extraction,
      literal children produce empty `AttributeSet`s and trip the planner's
      "exactly one attribute per partition expression" assertions. Real
      partitionings can't reach those assertions with literal-only transforms
      because `KeyedPartitioning.supportsExpressions`'s `isReferenc [...]
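
    The ordering contract in the first bullet can be illustrated with a small,
    self-contained Scala model. This is a hypothetical sketch, not Spark's
    `AttributeSet`: `Attr` and `MiniAttributeSet` are illustrative names. Only
    the two documented behaviors are mirrored: insertion-ordered iteration over
    a backing `LinkedHashSet`, and a `toSeq` that sorts by `(name, exprId)`.

```scala
import scala.collection.mutable

// Hypothetical stand-in for an attribute: a column name plus an expression id.
final case class Attr(name: String, exprId: Long)

// Hypothetical model of the documented ordering contract, not Spark's AttributeSet.
final class MiniAttributeSet(attrs: Seq[Attr]) extends Iterable[Attr] {
  // A LinkedHashSet drops duplicates but remembers insertion order.
  private val baseSet = mutable.LinkedHashSet.from(attrs)

  // iterator, foreach, flatMap, map, ... all go through this insertion-ordered view.
  override def iterator: Iterator[Attr] = baseSet.iterator

  // The documented exception: toSeq sorts by (name, exprId) for stable output.
  override def toSeq: Seq[Attr] = baseSet.toSeq.sortBy(a => (a.name, a.exprId))
}

val set = new MiniAttributeSet(
  Seq(Attr("b", 1L), Attr("a", 2L), Attr("a", 0L), Attr("b", 1L)))

// Insertion order (the duplicate Attr("b", 1L) is kept only once).
val iterated: List[Long] = set.iterator.map(_.exprId).toList
val viaFlatMap: List[Long] = set.flatMap(a => Seq(a.exprId)).toList
// Sorted order, independent of insertion order.
val sorted: List[Long] = set.toSeq.map(_.exprId).toList
```

    Only the `toSeq` result depends on names and ids. Anything that walks the
    set through `iterator` or `flatMap` sees insertion order, which is the
    property the positional call sites need.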
    
    ### Why are the changes needed?
    
    `TreeNode.collectLeaves()` returns every node in the tree where
    `children.isEmpty`, including `Literal`s. SPJ planning has always wanted
    attributes only. With the existing partition expression layout
    (`TransformExpression.children = [col]`, with parameters carried in a
    sidecar `numBucketsOpt: Option[Int]` field), the difference didn't surface.
    
    Follow-up work (e.g. SPARK-50593 / #55885) puts literal parameters directly
    into `TransformExpression.children` (`bucket(Literal(numBuckets), col)`,
    `truncate(col, Literal(width))`). Without this change, that work would force
    `TransformExpression` to override `collectLeaves` to filter literals. That
    override would break the universal `TreeNode.collectLeaves` contract for one
    expression type.
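
    The difference between the two methods shows up on a miniature expression
    tree. In this hypothetical sketch, `Expr`, `Col`, `Lit` and `Transform` are
    illustrative stand-ins, not Catalyst classes. `collectLeaves()` follows the
    `TreeNode` contract described above (every node whose `children` is empty),
    while `references` keeps only column attributes, as `Expression.references`
    does. Both the existing layout (parameter in a sidecar `numBucketsOpt`
    field) and the follow-up layout (literal parameters as children) are
    modeled.

```scala
// Hypothetical miniature expression tree; not Catalyst's TreeNode/Expression.
sealed trait Expr {
  def children: Seq[Expr]

  // Same contract as TreeNode.collectLeaves(): every node with no children, literals included.
  def collectLeaves(): Seq[Expr] =
    if (children.isEmpty) Seq(this) else children.flatMap(_.collectLeaves())

  // Attributes only, deduplicated, in first-seen order (like Expression.references).
  def references: Seq[Col] = children.flatMap(_.references).distinct
}

final case class Col(name: String) extends Expr {
  def children: Seq[Expr] = Nil
  override def references: Seq[Col] = Seq(this)
}

final case class Lit(value: Int) extends Expr {
  def children: Seq[Expr] = Nil // a leaf, but not an attribute
}

// Partition transform such as bucket or truncate; numBucketsOpt models the sidecar field.
final case class Transform(name: String, children: Seq[Expr], numBucketsOpt: Option[Int] = None)
    extends Expr

val col = Col("id")

// Existing layout: the bucket count lives outside children, so both methods agree.
val oldBucket = Transform("bucket", Seq(col), numBucketsOpt = Some(4))
// Follow-up layout: literal parameters become children, so collectLeaves() picks them up.
val newBucket = Transform("bucket", Seq(Lit(4), col))
val newTruncate = Transform("truncate", Seq(col, Lit(10)))
```

    A bare literal such as `Lit(1)` has no references at all. This is why the
    literal test fixtures could no longer satisfy an "exactly one attribute per
    partition expression" assertion.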
    
    `Expression.references` already returns attributes only (filtering literals
    and other non-attribute leaves). Its insertion-ordered iteration is exactly
    what positional binding (`RowOrdering.create`, `reorder`,
    `attributes.zip(clustering)`) requires. The per-partition-expression
    single-column rule (enforced by `KeyedPartitioning.supportsExpressions`)
    ensures that within-expression dedup never matters here. Documenting the
    iteration-order contract lets these call sites rely on the order [...]
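
    The positional-binding argument can be sketched the same way. Everything
    below is hypothetical and simplified. `singleColumnPerExpression` only
    mimics the single-column rule attributed to
    `KeyedPartitioning.supportsExpressions`, and `reorderToPartitionOrder` is a
    much-reduced stand-in for `reorder`, not its implementation. When each
    partition expression references exactly one column, `flatMap(_.references)`
    returns one column per expression, in expression order. Position `i` in
    the result therefore lines up with partition expression `i`.

```scala
// Hypothetical miniature expression tree; not Catalyst's classes.
sealed trait Expr {
  def children: Seq[Expr]
  // Attributes only, deduplicated, in first-seen order.
  def references: Seq[Col] = children.flatMap(_.references).distinct
}
final case class Col(name: String) extends Expr {
  def children: Seq[Expr] = Nil
  override def references: Seq[Col] = Seq(this)
}
final case class Lit(value: Int) extends Expr {
  def children: Seq[Expr] = Nil
}
final case class Transform(name: String, children: Seq[Expr]) extends Expr

// Hypothetical check mirroring the single-column rule: one column per partition expression.
def singleColumnPerExpression(partitionExprs: Seq[Expr]): Boolean =
  partitionExprs.forall(_.references.size == 1)

// Hypothetical, simplified reorder: permute the join key pairs so that the left keys
// follow the partition column order. None if a partition column is not a left key.
def reorderToPartitionOrder(
    leftKeys: Seq[Col],
    rightKeys: Seq[Col],
    partitionCols: Seq[Col]): Option[(Seq[Col], Seq[Col])] = {
  val positions = partitionCols.map(c => leftKeys.indexOf(c))
  if (positions.contains(-1)) None
  else Some((positions.map(i => leftKeys(i)), positions.map(i => rightKeys(i))))
}

val ts = Col("ts")
val id = Col("id")
// e.g. years(ts), bucket(4, id) in the follow-up layout with a literal child.
val partitionExprs = Seq(Transform("years", Seq(ts)), Transform("bucket", Seq(Lit(4), id)))

val invariantHolds = singleColumnPerExpression(partitionExprs)
// One column per expression, in expression order: position i <-> partition expression i.
val partitionCols = partitionExprs.flatMap(_.references)

val reordered = reorderToPartitionOrder(
  leftKeys = Seq(id, ts),
  rightKeys = Seq(Col("r_id"), Col("r_ts")),
  partitionCols = partitionCols)
```

    Suppose an expression referenced zero columns, or two. The flattened column
    list would then no longer line up with the partition expressions. The
    single-column check rules that case out before any positional binding
    happens.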
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Covered by existing test suites that exercise the migrated call sites:
    `KeyGroupedPartitioningSuite`, `EnsureRequirementsSuite`,
    `ProjectedOrderingAndPartitioningSuite`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Claude Opus 4.7
    
    Closes #56088 from peter-toth/SPARK-57038-revisit-collectleaves.
    
    Authored-by: Peter Toth <[email protected]>
    Signed-off-by: Peter Toth <[email protected]>
---
 .../spark/sql/catalyst/expressions/AttributeSet.scala   |  4 ++++
 .../sql/catalyst/plans/physical/partitioning.scala      | 10 +++++-----
 .../sql/execution/exchange/EnsureRequirements.scala     | 17 ++++++++++++-----
 .../execution/exchange/EnsureRequirementsSuite.scala    |  8 ++++----
 4 files changed, 25 insertions(+), 14 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
index 236380b2c030..d958cba27933 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
@@ -65,6 +65,10 @@ object AttributeSet {
  * `AttributeReference("a"...) == AttributeReference("b", ...)`. This tactic leads to broken tests,
  * and also makes doing transformations hard (we always try keep older trees instead of new ones
  * when the transformation was a no-op).
+ *
+ * Iteration via [[iterator]], [[foreach]], or [[Iterable]]-derived combinators (`flatMap`, etc.)
+ * visits elements in insertion order. Note: [[toSeq]] is an explicit exception -- it sorts by
+ * `(name, exprId.id)` for stable codegen output, see SPARK-18394.
  */
 class AttributeSet private (private val baseSet: mutable.LinkedHashSet[AttributeEquals])
   extends Iterable[Attribute] with Serializable {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index e92bb0f7c0d6..aeacdaec7a8d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -578,13 +578,13 @@ case class KeyedPartitioning(
           c.areAllClusterKeysMatched(expressions)
         } else {
           // We'll need to find leaf attributes from the partition expressions first.
-          lazy val attributes = expressions.flatMap(_.collectLeaves())
+          lazy val attributes = AttributeSet.fromAttributeSets(expressions.map(_.references))
 
           if (SQLConf.get.v2BucketingAllowKeysSubsetOfPartitionKeys) {
             // check that operation keys (required clustering keys)
             // overlap with partition keys (KeyedPartitioning attributes)
             requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) &&
-              expressions.forall(_.collectLeaves().size == 1)
+              expressions.forall(_.references.size == 1)
           } else if (isNarrowed && !isGrouped) {
             // A narrowed, non-grouped partitioning carries the same skew risk as using a subset of
             // partition keys for a join: GroupPartitionsExec will merge partitions that held
@@ -1218,9 +1218,9 @@ case class KeyedShuffleSpec(
       distKeyToPos.getOrElseUpdate(distKey.canonicalized, mutable.BitSet.empty).add(distKeyPos)
     }
     partitioning.expressions.map { e =>
-      val leaves = e.collectLeaves()
-      assert(leaves.size == 1, s"Expected exactly one child from $e, but found ${leaves.size}")
-      distKeyToPos.getOrElse(leaves.head.canonicalized, mutable.BitSet.empty)
+      val refs = e.references
+      assert(refs.size == 1, s"Expected exactly one child from $e, but found ${refs.size}")
+      distKeyToPos.getOrElse(refs.head.canonicalized, mutable.BitSet.empty)
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index c5a08e983e61..c632b3d841e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -86,8 +86,9 @@ case class EnsureRequirements(
                 // Find any KeyedPartitioning that satisfies via groupedSatisfies.
                 val satisfyingKeyedPartitioning =
                   groupedSatisfies.orElse(nonGroupedSatisfiesWhenGrouped).get
-                val attrs = satisfyingKeyedPartitioning.expressions.flatMap(_.collectLeaves())
-                  .map(_.asInstanceOf[Attribute])
+                // The single-column invariant in KeyedPartitioning.supportsExpressions guarantees
+                // one attribute per partition expression.
+                val attrs = satisfyingKeyedPartitioning.expressions.flatMap(_.references)
                 val keyRowOrdering = RowOrdering.create(o.ordering, attrs)
                 val keyOrdering = keyRowOrdering.on((t: InternalRowComparableWrapper) => t.row)
                 if (satisfyingKeyedPartitioning.partitionKeys.sliding(2).forall {
@@ -409,12 +410,16 @@ case class EnsureRequirements(
           .orElse(reorderJoinKeysRecursively(
             leftKeys, rightKeys, leftPartitioning, None))
       case (Some(KeyedPartitioning(clustering, _, _, _)), _) =>
-        val leafExprs = clustering.flatMap(_.collectLeaves())
+        // The single-column invariant in KeyedPartitioning.supportsExpressions guarantees one
+        // attribute per partition expression.
+        val leafExprs = clustering.flatMap(_.references)
         reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, leftKeys)
             .orElse(reorderJoinKeysRecursively(
               leftKeys, rightKeys, None, rightPartitioning))
       case (_, Some(KeyedPartitioning(clustering, _, _, _))) =>
-        val leafExprs = clustering.flatMap(_.collectLeaves())
+        // The single-column invariant in KeyedPartitioning.supportsExpressions guarantees one
+        // attribute per partition expression.
+        val leafExprs = clustering.flatMap(_.references)
         reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, rightKeys)
             .orElse(reorderJoinKeysRecursively(
               leftKeys, rightKeys, leftPartitioning, None))
@@ -777,7 +782,9 @@ case class EnsureRequirements(
       partitioning: Partitioning,
       distribution: ClusteredDistribution): Option[KeyedShuffleSpec] = {
     def tryCreate(partitioning: KeyedPartitioning): Option[KeyedShuffleSpec] = {
-      val attributes = partitioning.expressions.flatMap(_.collectLeaves())
+      // The single-column invariant in KeyedPartitioning.supportsExpressions guarantees one
+      // attribute per partition expression.
+      val attributes = partitioning.expressions.flatMap(_.references)
       val clustering = distribution.clustering
 
       val satisfies = if (SQLConf.get.getConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION)) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
index 74b706bce34f..17d00ec055e0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
@@ -40,10 +40,10 @@ import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
 
 class EnsureRequirementsSuite extends SharedSparkSession {
-  private val exprA = Literal(1)
-  private val exprB = Literal(2)
-  private val exprC = Literal(3)
-  private val exprD = Literal(4)
+  private val exprA = AttributeReference("a", IntegerType)()
+  private val exprB = AttributeReference("b", IntegerType)()
+  private val exprC = AttributeReference("c", IntegerType)()
+  private val exprD = AttributeReference("d", IntegerType)()
 
   private val EnsureRequirements = new EnsureRequirements()
 

