This is an automated email from the ASF dual-hosted git repository.

sunchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 421b8001107d [SPARK-56903][SQL] Spread NULL outer join keys across shuffle partitions
421b8001107d is described below

commit 421b8001107d73f0ea3c9c1313ee878124c37142
Author: Chao Sun <[email protected]>
AuthorDate: Fri May 22 09:40:05 2026 -0700

    [SPARK-56903][SQL] Spread NULL outer join keys across shuffle partitions
    
    ### What changes were proposed in this pull request?
    
    This PR reduces shuffle skew for null-heavy shuffled outer equi-joins.
    
    For `LEFT OUTER`, `RIGHT OUTER`, and `FULL OUTER` joins, preserved rows with a `NULL`
    shuffle key may not need to stay concentrated on one reducer. Today those rows can all
    collapse into the same shuffle partition, which creates avoidable skew on NULL-heavy inputs.
    
    This change adds a feature-flagged null-aware shuffle partitioning mode for shuffled outer
    joins:
    
    - Non-NULL shuffle keys keep the existing hash partitioning behavior.
    - Rows with any `NULL` shuffle key are spread across reducers instead of collapsing into one
      partition.
    - The behavior is disabled by default behind
      `spark.sql.shuffle.spreadNullJoinKeys.enabled`.
    - The optimization is considered only for `LEFT OUTER`, `RIGHT OUTER`, and `FULL OUTER`
      equi-joins whose preserved side has nullable join keys.
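    
    A minimal sketch of the partitioning rule in the list above, in plain Scala with no Spark
    dependency. `NullAwarePartitionerSketch`, `partitioner`, and the `Option[Int]` key model are
    hypothetical names used only for illustration. The hash here is Scala's own `hashCode`, not
    the hash Spark uses, and the NULL counter starts at a fixed value rather than a per-task seed.
    
    ```scala
    object NullAwarePartitionerSketch {
      // One entry per join-key column; None models a SQL NULL.
      type Keys = Seq[Option[Int]]

      // Non-negative modulo, so negative hash codes still map to a valid partition.
      private def pmod(a: Int, n: Int): Int = ((a % n) + n) % n

      // Returns a stateful row-to-partition function, one instance per map task.
      // `start` is an assumed stand-in for a per-task starting offset.
      def partitioner(numPartitions: Int, start: Int = 0): Keys => Int = {
        var nextNullPartition = pmod(start, numPartitions)
        (keys: Keys) =>
          if (keys.exists(_.isEmpty)) {
            // Any NULL key: hand out partitions round-robin instead of hashing.
            val partition = nextNullPartition
            nextNullPartition = (nextNullPartition + 1) % numPartitions
            partition
          } else {
            // All keys non-NULL: equal keys always land in the same partition.
            pmod(keys.flatten.hashCode(), numPartitions)
          }
      }
    }
    ```
    
    Because the NULL branch depends on row visitation order, a retried task only reproduces the
    same assignment if it sees its rows in the same order.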
    
    Spreading remains result-safe for null-safe equality (`<=>`) outer joins:
    
    - For ordinary extracted `<=>` join keys, Spark rewrites them into non-null shuffle-key
      expressions using `coalesce(...)` and `isnull(...)`, so there are no `NULL` shuffle keys for
      this feature to redistribute.
    - The only remaining corner is `NullType`, where the shuffle key can still be `NULL`. In that
      case, shuffled join execution already treats the row as unmatched, so redistributing those
      rows does not change query results.
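    
    The `coalesce(...)` and `isnull(...)` rewrite in the first bullet can be pictured with a small
    plain-Scala sketch. `NullSafeKeySketch` and `nullSafeKey` are hypothetical names, and the
    placeholder value `0` is an assumption chosen for the example, not taken from Spark.
    
    ```scala
    object NullSafeKeySketch {
      // Maps a nullable key to two non-null shuffle keys: the value with a
      // placeholder substituted for NULL, plus a flag recording whether it was NULL.
      def nullSafeKey(key: Option[Int]): (Int, Boolean) =
        (key.getOrElse(0), key.isEmpty)
    }
    ```
    
    Two NULL keys map to the same pair and therefore hash to the same partition, while the flag
    keeps a NULL distinct from a real value equal to the placeholder.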
    
    The implementation wires this through the planner and runtime pieces that need to understand
    the new partitioning contract:
    
    - `ClusteredDistribution` can opt into null-aware spreading.
    - New null-aware partitioning and shuffle-spec variants preserve compatibility checks without
      pretending to satisfy ordinary clustered distributions.
    - Shuffle execution spreads unmatched `NULL` keys while preserving retry safety.
    - AQE/coalesced shuffle reads preserve the new partitioning shape.
    
    When the feature flag is enabled, the null-aware join output partitioning intentionally does not
    satisfy a strict `ClusteredDistribution`. That can require an extra downstream shuffle for
    grouping, windowing, or another equi-join on the same key. Also, if one side is already hash
    partitioned, only the other side may be reshuffled into the null-aware layout, so the
    pre-shuffled side can keep its NULL skew.
    
    This PR intentionally stays scoped to outer joins. Left anti joins may also have skewed
    preserved-side `NULL` rows for ordinary `=` predicates and are worth evaluating separately, but
    they need their own correctness and planning review rather than being folded into this patch.
    
    ### Why are the changes needed?
    
    Outer joins can preserve large numbers of unmatched rows from the outer side. When many of those
    rows have `NULL` shuffle keys, sending them all to one reducer creates skew even though they do
    not require one shared reducer for correctness.
    
    Example:
    
    ```sql
    SELECT *
    FROM fact f
    LEFT OUTER JOIN dim d
      ON f.k = d.k
    ```
    
    If `fact.k` contains many `NULL` values, those rows must remain in the result as unmatched
    left-side rows, but they do not need to be grouped together for correctness. Spreading them
    reduces needless reducer concentration while leaving normal key matching unchanged.
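    
    The claim that NULL-keyed rows need no shared reducer can be checked on a toy model. The
    sketch below is plain Scala, not Spark code: `SpreadNullLeftOuterJoinSketch` and its methods
    are hypothetical, keys are modeled as `Option[Int]`, and partitions are in-memory sequences.
    
    ```scala
    object SpreadNullLeftOuterJoinSketch {
      // (join key, payload); None models a SQL NULL key.
      type Row = (Option[Int], String)

      private def pmod(a: Int, n: Int): Int = ((a % n) + n) % n

      // LEFT OUTER JOIN inside one partition. A NULL key never matches under `=`.
      def leftOuterJoin(left: Seq[Row], right: Seq[Row]): Seq[(String, Option[String])] =
        left.flatMap { case (key, leftValue) =>
          val matches =
            if (key.isEmpty) Seq.empty[String]
            else right.filter(_._1 == key).map(_._2)
          if (matches.isEmpty) Seq((leftValue, Option.empty[String]))
          else matches.map(rightValue => (leftValue, Option(rightValue)))
        }

      // Shuffles both sides into `n` partitions, joins each partition, and
      // concatenates the per-partition outputs.
      def partitionedJoin(
          left: Seq[Row],
          right: Seq[Row],
          n: Int,
          spreadNulls: Boolean): Seq[(String, Option[String])] = {
        var nextNull = 0
        def leftPartition(key: Option[Int]): Int = key match {
          case Some(v) => pmod(v, n)
          case None if spreadNulls =>
            // Preserved-side NULL keys are handed out round-robin.
            val p = nextNull
            nextNull = (nextNull + 1) % n
            p
          case None => 0 // Without spreading, every NULL key lands in partition 0.
        }
        val leftParts = left.groupBy(row => leftPartition(row._1))
        val rightParts = right.groupBy(row => row._1.map(pmod(_, n)).getOrElse(0))
        (0 until n).flatMap { p =>
          leftOuterJoin(leftParts.getOrElse(p, Seq.empty), rightParts.getOrElse(p, Seq.empty))
        }
      }
    }
    ```
    
    Calling `partitionedJoin` with `spreadNulls = true` and `spreadNulls = false` on the same
    inputs yields the same rows up to ordering; only the partition that emits each unmatched
    NULL-keyed row differs.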
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, in execution behavior only. Query results are unchanged, but when the feature flag is
    enabled, shuffle partitioning for eligible NULL-heavy outer equi-joins becomes less skewed.
    
    ### How was this patch tested?
    
    - Added and updated unit tests covering outer-join planning, FULL OUTER JOIN result correctness
      with `NULL` keys, null-safe outer-join behavior, shuffle-level `NULL` spreading, retry
      determinism, shuffle-spec compatibility, and AQE preservation of null-aware coalesced reads.
    - Ran focused plan-stability verification for the affected TPC-DS cases locally.
    - Ran `git diff --check`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Codex GPT-5
    
    Closes #55927 from sunchao/dev/chao/codex/null-aware-outer-join-apache.
    
    Authored-by: Chao Sun <[email protected]>
    Signed-off-by: Chao Sun <[email protected]>
---
 .../sql/catalyst/plans/physical/partitioning.scala | 216 +++++++++++++-
 .../org/apache/spark/sql/internal/SQLConf.scala    |  14 +
 .../spark/sql/catalyst/ShuffleSpecSuite.scala      |  60 ++++
 .../execution/adaptive/AQEShuffleReadExec.scala    |   9 +-
 .../execution/exchange/ShuffleExchangeExec.scala   |  47 ++-
 .../spark/sql/execution/joins/ShuffledJoin.scala   |  22 ++
 .../DistributionAndOrderingSuiteBase.scala         |   2 +-
 .../connector/KeyGroupedPartitioningSuite.scala    |   2 +-
 .../apache/spark/sql/execution/ExchangeSuite.scala |  37 ++-
 .../adaptive/AdaptiveQueryExecSuite.scala          | 151 +++++++---
 .../spark/sql/execution/joins/OuterJoinSuite.scala | 318 ++++++++++++++++++++-
 11 files changed, 806 insertions(+), 72 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index f331cd124759..e92bb0f7c0d6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -81,12 +81,17 @@ case object AllTuples extends Distribution {
  *
  * @param requireAllClusterKeys When true, `Partitioning` which satisfies this distribution,
  *                              must match all `clustering` expressions in the same ordering.
+ * @param allowNullKeySpreading When true, the default partitioning may spread rows whose
+ *                              clustering keys contain NULL values. This is a permission for
+ *                              consumers that do not require NULL-key co-location; ordinary
+ *                              [[HashPartitioning]] can still satisfy this distribution.
  */
 case class ClusteredDistribution(
     clustering: Seq[Expression],
     requireAllClusterKeys: Boolean = SQLConf.get.getConf(
       SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION),
-    requiredNumPartitions: Option[Int] = None) extends Distribution {
+    requiredNumPartitions: Option[Int] = None,
+    allowNullKeySpreading: Boolean = false) extends Distribution {
   require(
     clustering != Nil,
     "The clustering expressions of a ClusteredDistribution should not be Nil. " +
@@ -97,7 +102,11 @@ case class ClusteredDistribution(
     assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions,
       s"This ClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " +
         s"the actual number of partitions is $numPartitions.")
-    HashPartitioning(clustering, numPartitions)
+    if (allowNullKeySpreading) {
+      NullAwareHashPartitioning(clustering, numPartitions)
+    } else {
+      HashPartitioning(clustering, numPartitions)
+    }
   }
 
   /**
@@ -282,7 +291,7 @@ trait HashPartitioningLike extends Expression with Partitioning with Unevaluable
           expressions.length == h.expressions.length && expressions.zip(h.expressions).forall {
             case (l, r) => l.semanticEquals(r)
           }
-        case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) =>
+        case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _, _) =>
           if (requireAllClusterKeys) {
             // Checks `HashPartitioning` is partitioned on exactly same clustering keys of
             // `ClusteredDistribution`.
@@ -324,6 +333,45 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
     newChildren: IndexedSeq[Expression]): HashPartitioning = copy(expressions = newChildren)
 }
 
+/**
+ * Represents a hash partitioning for equi-join inputs where rows with a NULL join key do not need
+ * to be co-located. Non-NULL join keys preserve the same partitioning contract as
+ * [[HashPartitioning]], while rows with any NULL join key may be spread across partitions. As a
+ * result, this partitioning intentionally does not satisfy a strict [[ClusteredDistribution]].
+ */
+case class NullAwareHashPartitioning(expressions: Seq[Expression], numPartitions: Int)
+  extends HashPartitioningLike {
+
+  override def satisfies0(required: Distribution): Boolean = {
+    (required match {
+      case UnspecifiedDistribution => true
+      case AllTuples => numPartitions == 1
+      case _ => false
+    }) || {
+      // Stateful operators require strict NULL-key co-location and therefore cannot consume
+      // null-aware hash partitioning as a compatible clustered layout.
+      required match {
+        case c @ ClusteredDistribution(
+            requiredClustering, requireAllClusterKeys, _, allowNullKeySpreading)
+            if allowNullKeySpreading =>
+          if (requireAllClusterKeys) {
+            c.areAllClusterKeysMatched(expressions)
+          } else {
+            expressions.forall(x => requiredClustering.exists(_.semanticEquals(x)))
+          }
+        case _ => false
+      }
+    }
+  }
+
+  override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec =
+    NullAwareHashShuffleSpec(this, distribution)
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): NullAwareHashPartitioning =
+    copy(expressions = newChildren)
+}
+
 case class CoalescedBoundary(startReducerIndex: Int, endReducerIndex: Int)
 
 /**
@@ -345,6 +393,47 @@ case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[Coa
     copy(from = from.copy(expressions = newChildren))
 }
 
+/**
+ * Represents a null-aware hash partitioning whose reducer ranges have been coalesced into fewer
+ * partitions. It preserves the same relaxed NULL-key co-location contract as
+ * [[NullAwareHashPartitioning]].
+ */
+case class CoalescedNullAwareHashPartitioning(
+    from: NullAwareHashPartitioning,
+    partitions: Seq[CoalescedBoundary]) extends HashPartitioningLike {
+
+  override def expressions: Seq[Expression] = from.expressions
+
+  override def satisfies0(required: Distribution): Boolean = {
+    (required match {
+      case UnspecifiedDistribution => true
+      case AllTuples => numPartitions == 1
+      case _ => false
+    }) || {
+      required match {
+        case c @ ClusteredDistribution(
+            requiredClustering, requireAllClusterKeys, _, allowNullKeySpreading)
+            if allowNullKeySpreading =>
+          if (requireAllClusterKeys) {
+            c.areAllClusterKeysMatched(expressions)
+          } else {
+            expressions.forall(x => requiredClustering.exists(_.semanticEquals(x)))
+          }
+        case _ => false
+      }
+    }
+  }
+
+  override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec =
+    CoalescedHashShuffleSpec(from.createShuffleSpec(distribution), partitions)
+
+  override val numPartitions: Int = partitions.length
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): CoalescedNullAwareHashPartitioning =
+    copy(from = from.copy(expressions = newChildren))
+}
+
 /**
  * Represents a partitioning where rows are split across partitions based on transforms defined by
  * `expressions`.
@@ -482,7 +571,7 @@ case class KeyedPartitioning(
 
   def groupedSatisfies(required: Distribution): Boolean = {
     required match {
-      case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) =>
+      case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _, _) =>
         if (requireAllClusterKeys) {
           // Checks whether this partitioning is partitioned on exactly same clustering keys of
           // `ClusteredDistribution`.
@@ -657,7 +746,7 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
           //   `RangePartitioning(a, b, c)` satisfies `OrderedDistribution(a, b)`.
           val minSize = Seq(requiredOrdering.size, ordering.size).min
           requiredOrdering.take(minSize) == ordering.take(minSize)
-        case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) =>
+        case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _, _) =>
           val expressions = ordering.map(_.child)
           if (requireAllClusterKeys) {
             // Checks `RangePartitioning` is partitioned on exactly same clustering keys of
@@ -838,7 +927,7 @@ case class ShufflePartitionIdPassThrough(
     super.satisfies0(required) || {
       required match {
         // TODO(SPARK-53428): Support Direct Passthrough Partitioning in the Streaming Joins
-        case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) =>
+        case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _, _) =>
           val partitioningExpressions = expr.child :: Nil
           if (requireAllClusterKeys) {
             c.areAllClusterKeysMatched(partitioningExpressions)
@@ -919,6 +1008,25 @@ case class RangeShuffleSpec(
   }
 }
 
+private object HashShuffleSpecCompatibility {
+  def isCompatible(
+      leftDistribution: ClusteredDistribution,
+      leftNumPartitions: Int,
+      leftExpressions: Seq[Expression],
+      leftHashKeyPositions: Seq[mutable.BitSet],
+      rightDistribution: ClusteredDistribution,
+      rightNumPartitions: Int,
+      rightExpressions: Seq[Expression],
+      rightHashKeyPositions: Seq[mutable.BitSet]): Boolean = {
+    leftDistribution.clustering.length == rightDistribution.clustering.length &&
+    leftNumPartitions == rightNumPartitions &&
+    leftExpressions.length == rightExpressions.length &&
+    leftHashKeyPositions.zip(rightHashKeyPositions).forall { case (left, right) =>
+      left.intersect(right).nonEmpty
+    }
+  }
+}
+
 case class HashShuffleSpec(
     partitioning: HashPartitioning,
     distribution: ClusteredDistribution) extends ShuffleSpec {
@@ -951,14 +1059,26 @@ case class HashShuffleSpec(
       //  3. both partitioning have the same number of expressions
       //  4. each pair of partitioning expression from both sides has overlapping positions in their
       //     corresponding distributions.
-      distribution.clustering.length == otherDistribution.clustering.length &&
-      partitioning.numPartitions == otherPartitioning.numPartitions &&
-      partitioning.expressions.length == otherPartitioning.expressions.length && {
-        val otherHashKeyPositions = otherHashSpec.hashKeyPositions
-        hashKeyPositions.zip(otherHashKeyPositions).forall { case (left, right) =>
-          left.intersect(right).nonEmpty
-        }
-      }
+      HashShuffleSpecCompatibility.isCompatible(
+        distribution,
+        partitioning.numPartitions,
+        partitioning.expressions,
+        hashKeyPositions,
+        otherDistribution,
+        otherPartitioning.numPartitions,
+        otherPartitioning.expressions,
+        otherHashSpec.hashKeyPositions)
+    case otherNullAwareSpec @ NullAwareHashShuffleSpec(otherPartitioning, otherDistribution)
+        if distribution.allowNullKeySpreading && otherDistribution.allowNullKeySpreading =>
+      HashShuffleSpecCompatibility.isCompatible(
+        distribution,
+        partitioning.numPartitions,
+        partitioning.expressions,
+        hashKeyPositions,
+        otherDistribution,
+        otherPartitioning.numPartitions,
+        otherPartitioning.expressions,
+        otherNullAwareSpec.hashKeyPositions)
     case ShuffleSpecCollection(specs) =>
       specs.exists(isCompatibleWith)
     case _ =>
@@ -979,7 +1099,73 @@ case class HashShuffleSpec(
 
   override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
     val exprs = hashKeyPositions.map(v => clustering(v.head))
-    HashPartitioning(exprs, partitioning.numPartitions)
+    if (distribution.allowNullKeySpreading) {
+      NullAwareHashPartitioning(exprs, partitioning.numPartitions)
+    } else {
+      HashPartitioning(exprs, partitioning.numPartitions)
+    }
+  }
+
+  override def numPartitions: Int = partitioning.numPartitions
+}
+
+/**
+ * Shuffle specification for [[NullAwareHashPartitioning]]. It is compatible only with shuffle
+ * layouts whose distributions explicitly allow NULL-key spreading.
+ */
+case class NullAwareHashShuffleSpec(
+    partitioning: NullAwareHashPartitioning,
+    distribution: ClusteredDistribution) extends ShuffleSpec {
+
+  lazy val hashKeyPositions: Seq[mutable.BitSet] = {
+    val distKeyToPos = mutable.Map.empty[Expression, mutable.BitSet]
+    distribution.clustering.zipWithIndex.foreach { case (distKey, distKeyPos) =>
+      distKeyToPos.getOrElseUpdate(distKey.canonicalized, mutable.BitSet.empty).add(distKeyPos)
+    }
+    partitioning.expressions.map(k => distKeyToPos.getOrElse(k.canonicalized, mutable.BitSet.empty))
+  }
+
+  override def isCompatibleWith(other: ShuffleSpec): Boolean = other match {
+    case SinglePartitionShuffleSpec =>
+      partitioning.numPartitions == 1
+    case otherSpec @ NullAwareHashShuffleSpec(otherPartitioning, otherDistribution) =>
+      HashShuffleSpecCompatibility.isCompatible(
+        distribution,
+        partitioning.numPartitions,
+        partitioning.expressions,
+        hashKeyPositions,
+        otherDistribution,
+        otherPartitioning.numPartitions,
+        otherPartitioning.expressions,
+        otherSpec.hashKeyPositions)
+    case otherHashSpec @ HashShuffleSpec(otherPartitioning, otherDistribution)
+        if distribution.allowNullKeySpreading && otherDistribution.allowNullKeySpreading =>
+      HashShuffleSpecCompatibility.isCompatible(
+        distribution,
+        partitioning.numPartitions,
+        partitioning.expressions,
+        hashKeyPositions,
+        otherDistribution,
+        otherPartitioning.numPartitions,
+        otherPartitioning.expressions,
+        otherHashSpec.hashKeyPositions)
+    case ShuffleSpecCollection(specs) =>
+      specs.exists(isCompatibleWith)
+    case _ =>
+      false
+  }
+
+  override def canCreatePartitioning: Boolean = {
+    if (SQLConf.get.getConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION)) {
+      distribution.areAllClusterKeysMatched(partitioning.expressions)
+    } else {
+      true
+    }
+  }
+
+  override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
+    val exprs = hashKeyPositions.map(v => clustering(v.head))
+    NullAwareHashPartitioning(exprs, partitioning.numPartitions)
   }
 
   override def numPartitions: Int = partitioning.numPartitions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index c34b52e15dbc..8ab725350448 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -967,6 +967,20 @@ object SQLConf {
     .checkValue(_ > 0, "The value of spark.sql.shuffle.partitions must be positive")
     .createWithDefault(200)
 
+  val SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED =
+    buildConf("spark.sql.shuffle.spreadNullJoinKeys.enabled")
+      .doc("When true, Spark may spread rows with NULL equi-join keys across shuffle partitions " +
+        "for shuffled LEFT, RIGHT, and FULL OUTER equi-joins on nullable keys to reduce " +
+        "shuffle skew. Null-aware join output partitioning does not satisfy a strict " +
+        "ClusteredDistribution, so downstream grouping, windowing, or equi-joins may require " +
+        "an extra shuffle. If one input is already hash partitioned, only the other input may " +
+        "be reshuffled into the null-aware layout, so the pre-shuffled input can keep its NULL " +
+        "skew.")
+      .version("4.1.0")
+      .withBindingPolicy(ConfigBindingPolicy.SESSION)
+      .booleanConf
+      .createWithDefault(false)
+
   val SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED =
     buildConf("spark.sql.shuffle.orderIndependentChecksum.enabled")
       .doc("Whether to calculate order independent checksum for the shuffle data or not. If " +
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala
index 85d285aa76c0..cb5d77d44512 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala
@@ -453,6 +453,66 @@ class ShuffleSpecSuite extends SparkFunSuite with SQLHelper {
     )
   }
 
+  test("compatibility: NullAwareHashShuffleSpec") {
+    val spreadAB = ClusteredDistribution(Seq($"a", $"b"), allowNullKeySpreading = true)
+    val spreadCD = ClusteredDistribution(Seq($"c", $"d"), allowNullKeySpreading = true)
+    val regularAB = ClusteredDistribution(Seq($"a", $"b"))
+
+    val nullAwareAB = NullAwareHashShuffleSpec(
+      NullAwareHashPartitioning(Seq($"a", $"b"), 10), spreadAB)
+    val nullAwareCD = NullAwareHashShuffleSpec(
+      NullAwareHashPartitioning(Seq($"c", $"d"), 10), spreadCD)
+    val regularABSpec = HashShuffleSpec(
+      HashPartitioning(Seq($"a", $"b"), 10), regularAB)
+    val spreadABHashSpec = HashShuffleSpec(
+      HashPartitioning(Seq($"a", $"b"), 10), spreadAB)
+
+    checkCompatible(nullAwareAB, nullAwareCD, expected = true)
+    checkCompatible(nullAwareAB, SinglePartitionShuffleSpec, expected = false)
+    checkCompatible(
+      NullAwareHashShuffleSpec(NullAwareHashPartitioning(Seq($"a", $"b"), 1), spreadAB),
+      SinglePartitionShuffleSpec,
+      expected = true)
+    checkCompatible(nullAwareAB, regularABSpec, expected = false)
+    checkCompatible(nullAwareAB, spreadABHashSpec, expected = true)
+    checkCompatible(spreadABHashSpec, nullAwareAB, expected = true)
+  }
+
+  test("canCreatePartitioning: NullAwareHashShuffleSpec") {
+    val spreadDistribution =
+      ClusteredDistribution(Seq($"a", $"b"), allowNullKeySpreading = true)
+    val partialSpec = NullAwareHashShuffleSpec(
+      NullAwareHashPartitioning(Seq($"a"), 10), spreadDistribution)
+    val fullSpec = NullAwareHashShuffleSpec(
+      NullAwareHashPartitioning(Seq($"a", $"b"), 10), spreadDistribution)
+
+    withSQLConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false") {
+      assert(partialSpec.canCreatePartitioning)
+    }
+    withSQLConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "true") {
+      assert(!partialSpec.canCreatePartitioning)
+      assert(fullSpec.canCreatePartitioning)
+    }
+  }
+
+  test("createPartitioning: NullAwareHashShuffleSpec") {
+    checkCreatePartitioning(
+      NullAwareHashShuffleSpec(
+        NullAwareHashPartitioning(Seq($"a"), 10),
+        ClusteredDistribution(Seq($"a", $"b"), allowNullKeySpreading = true)),
+      ClusteredDistribution(Seq($"c", $"d"), allowNullKeySpreading = true),
+      NullAwareHashPartitioning(Seq($"c"), 10)
+    )
+
+    checkCreatePartitioning(
+      HashShuffleSpec(
+        HashPartitioning(Seq($"a"), 10),
+        ClusteredDistribution(Seq($"a", $"b"), allowNullKeySpreading = true)),
+      ClusteredDistribution(Seq($"c", $"d"), allowNullKeySpreading = true),
+      NullAwareHashPartitioning(Seq($"c"), 10)
+    )
+  }
+
   test("createPartitioning: other specs") {
     val distribution = ClusteredDistribution(Seq($"a", $"b"))
     checkCreatePartitioning(SinglePartitionShuffleSpec,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala
index eba0346a94bd..bff86983961c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala
@@ -23,7 +23,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
-import org.apache.spark.sql.catalyst.plans.physical.{CoalescedBoundary, CoalescedHashPartitioning, HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition, UnknownPartitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{CoalescedBoundary, CoalescedHashPartitioning, CoalescedNullAwareHashPartitioning, HashPartitioning, NullAwareHashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition, UnknownPartitioning}
 import org.apache.spark.sql.catalyst.trees.CurrentOrigin
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeLike}
@@ -83,6 +83,13 @@ case class AQEShuffleReadExec private(
               throw SparkException.internalError(s"Unexpected ShufflePartitionSpec: $unexpected")
           }
           CurrentOrigin.withOrigin(h.origin)(CoalescedHashPartitioning(h, partitions))
+        case h: NullAwareHashPartitioning =>
+          val partitions = partitionSpecs.map {
+            case CoalescedPartitionSpec(start, end, _) => CoalescedBoundary(start, end)
+            case unexpected =>
+              throw SparkException.internalError(s"Unexpected ShufflePartitionSpec: $unexpected")
+          }
+          CurrentOrigin.withOrigin(h.origin)(CoalescedNullAwareHashPartitioning(h, partitions))
         case r: RangePartitioning =>
           CurrentOrigin.withOrigin(r.origin)(r.copy(numPartitions = partitionSpecs.length))
         // This can only happen for `REBALANCE_PARTITIONS_BY_NONE`, which uses
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index 744438422916..114f221c52f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -30,7 +30,9 @@ import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter, ShuffleWriteProcessor}
 import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow, UnsafeRowChecksum}
+import org.apache.spark.sql.catalyst.expressions.{
+  Attribute, BoundReference, CollationAwareMurmur3Hash, Literal, Pmod, UnsafeProjection,
+  UnsafeRow, UnsafeRowChecksum}
 import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
 import org.apache.spark.sql.catalyst.plans.logical.Statistics
@@ -349,6 +351,10 @@ object ShuffleExchangeExec {
         // For HashPartitioning, the partitioning key is already a valid partition ID, as we use
         // `HashPartitioning.partitionIdExpression` to produce partitioning key.
         new PartitionIdPassthrough(n)
+      case NullAwareHashPartitioning(_, n) =>
+        // The null-aware extractor below produces partition IDs directly:
+        // Pmod(hash, n) for non-NULL keys, and a round-robin counter for NULL keys.
+        new PartitionIdPassthrough(n)
       case ShufflePartitionIdPassThrough(_, n) =>
         // For ShufflePartitionIdPassThrough, the DirectShufflePartitionID expression directly
         // produces partition IDs, so we use PartitionIdPassthrough to pass them through directly.
@@ -403,6 +409,32 @@ object ShuffleExchangeExec {
       case h: HashPartitioning =>
         val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes)
         row => projection(row).getInt(0)
+      case h: NullAwareHashPartitioning =>
+        // Non-NULL keys must produce the same partition id as
+        // HashPartitioning.partitionIdExpression so opted-in HashShuffleSpec and
+        // NullAwareHashShuffleSpec inputs stay aligned.
+        val joinKeyProjection = UnsafeProjection.create(h.expressions, outputAttributes)
+        val boundJoinKeys = h.expressions.zipWithIndex.map { case (expr, index) =>
+          BoundReference(index, expr.dataType, expr.nullable)
+        }
+        val partitionIdExpression = Pmod(
+          new CollationAwareMurmur3Hash(boundJoinKeys),
+          Literal(h.numPartitions))
+        val partitionIdProjection = UnsafeProjection.create(partitionIdExpression :: Nil)
+        var nullKeyPartition =
+          new XORShiftRandom(TaskContext.get().partitionId()).nextInt(h.numPartitions)
+        row => {
+          val joinKeys = joinKeyProjection(row)
+          if (joinKeys.anyNull()) {
+            // NULL join keys cannot match under ordinary equi-join semantics. Spread them
+            // round-robin within each map task so identical rows do not collapse to one reducer.
+            val partition = nullKeyPartition
+            nullKeyPartition = (nullKeyPartition + 1) % h.numPartitions
+            partition
+          } else {
+            partitionIdProjection(joinKeys).getInt(0)
+          }
+        }
       case RangePartitioning(sortingExpressions, _) =>
         val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
         row => projection(row)
@@ -419,9 +451,14 @@ object ShuffleExchangeExec {
 
     val isRoundRobin = newPartitioning.isInstanceOf[RoundRobinPartitioning] &&
       newPartitioning.numPartitions > 1
+    val isNullAwareHashPartitioning =
+      newPartitioning.isInstanceOf[NullAwareHashPartitioning] &&
+        newPartitioning.numPartitions > 1
+    val needsDeterministicLocalSort =
+      (isRoundRobin || isNullAwareHashPartitioning) && SQLConf.get.sortBeforeRepartition
 
     val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = {
-      // [SPARK-23207] Have to make sure the generated RoundRobinPartitioning is deterministic,
+      // [SPARK-23207] Have to make sure stateful row-to-partition assignment is deterministic,
       // otherwise a retry task may output different rows and thus lead to data loss.
       //
       // Currently we following the most straight-forward way that perform a local sort before
@@ -429,7 +466,7 @@ object ShuffleExchangeExec {
       //
       // Note that we don't perform local sort if the new partitioning has only 1 partition, under
       // that case all output rows go to the same partition.
-      val newRdd = if (isRoundRobin && SQLConf.get.sortBeforeRepartition) {
+      val newRdd = if (needsDeterministicLocalSort) {
         rdd.mapPartitionsInternal { iter =>
           val recordComparatorSupplier = new Supplier[RecordComparator] {
             override def get: RecordComparator = new RecordBinaryComparator()
@@ -468,7 +505,9 @@ object ShuffleExchangeExec {
       }
 
       // round-robin function is order sensitive if we don't sort the input.
-      val isOrderSensitive = isRoundRobin && !SQLConf.get.sortBeforeRepartition
+      // Stateful partition assignment is order-sensitive when it depends on row visitation order.
+      val isOrderSensitive =
+        (isRoundRobin || isNullAwareHashPartitioning) && !SQLConf.get.sortBeforeRepartition
       if (needToCopyObjectsBeforeShuffle(part)) {
         newRdd.mapPartitionsWithIndexInternal((_, iter) => {
           val getPartitionKey = getPartitionKeyExtractor()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
index 3fb968bfea7a..179f88c99af6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, LeftExistence, LeftOuter, LeftSingle, RightOuter}
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning, PartitioningCollection, UnknownPartitioning, UnspecifiedDistribution}
+import org.apache.spark.sql.internal.SQLConf
 
 /**
  * Holds common logic for join operators by shuffling two child relations
@@ -28,6 +29,24 @@ import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Dist
 trait ShuffledJoin extends JoinCodegenSupport {
   def isSkewJoin: Boolean
 
+  private lazy val canSpreadNullJoinKeys: Boolean = {
+    // Only NULL keys on the preserved side can create this skew: they must be emitted, but
+    // cannot satisfy ordinary equi-join predicates. Non-preserved NULL-keyed rows are filtered
+    // out by `=` and never emitted, so their reducer placement does not matter here.
+    //
+    // Null-safe equality usually rewrites to non-null shuffle keys. The NullType corner can still
+    // produce NULL shuffle keys, but shuffled join execution already treats those rows as
+    // unmatched, so spreading them does not change the result.
+    val preservedSideHasNullableKeys = joinType match {
+      case LeftOuter => leftKeys.exists(_.nullable)
+      case RightOuter => rightKeys.exists(_.nullable)
+      case FullOuter => leftKeys.exists(_.nullable) || rightKeys.exists(_.nullable)
+      case _ => false
+    }
+    conf.getConf(SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED) &&
+      preservedSideHasNullableKeys
+  }
+
   override def nodeName: String = {
     if (isSkewJoin) super.nodeName + "(skew=true)" else super.nodeName
   }
@@ -39,6 +58,9 @@ trait ShuffledJoin extends JoinCodegenSupport {
       // We re-arrange the shuffle partitions to deal with skew join, and the new children
       // partitioning doesn't satisfy `ClusteredDistribution`.
       UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
+    } else if (canSpreadNullJoinKeys) {
+      ClusteredDistribution(leftKeys, allowNullKeySpreading = true) ::
+        ClusteredDistribution(rightKeys, allowNullKeySpreading = true) :: Nil
     } else {
       ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala
index 0273a5d6dd49..c1741cac8ad3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala
@@ -67,7 +67,7 @@ abstract class DistributionAndOrderingSuiteBase
   protected def resolveDistribution[T <: QueryPlan[T]](
       distribution: physical.Distribution,
       plan: QueryPlan[T]): physical.Distribution = distribution match {
-    case physical.ClusteredDistribution(clustering, numPartitions, _) =>
+    case physical.ClusteredDistribution(clustering, numPartitions, _, _) =>
       physical.ClusteredDistribution(clustering.map(resolveAttrs(_, plan)), numPartitions)
     case physical.OrderedDistribution(ordering) =>
       physical.OrderedDistribution(ordering.map(resolveAttrs(_, plan).asInstanceOf[SortOrder]))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index 2a0ab52c3693..711f6dbdcdb1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -233,7 +233,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with
     }.head
 
     resolveDistribution(distribution, relation) match {
-      case physical.ClusteredDistribution(clustering, _, _) =>
+      case physical.ClusteredDistribution(clustering, _, _, _) =>
         assert(relation.keyGroupedPartitioning.isDefined &&
           relation.keyGroupedPartitioning.get == clustering)
       case _ =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
index 554cf5111bea..b7798b0bde5d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
@@ -19,11 +19,11 @@ package org.apache.spark.sql.execution
 
 import scala.util.Random
 
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{DeterministicLevel, RDD}
 import org.apache.spark.sql.{Dataset, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
-import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, IdentityBroadcastMode, SinglePartition}
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, IdentityBroadcastMode, NullAwareHashPartitioning, SinglePartition}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
 import org.apache.spark.sql.internal.SQLConf
@@ -59,6 +59,39 @@ class ExchangeSuite extends SharedSparkSession {
     )
   }
 
+  test("null-aware hash shuffle spreads identical NULL keys from one mapper") {
+    val input = Seq.fill(64)(Tuple1(null.asInstanceOf[Integer])).toDF("k").coalesce(1)
+    val plan = input.queryExecution.executedPlan
+    val exchange = ShuffleExchangeExec(NullAwareHashPartitioning(plan.output, 4), plan)
+    val partitionSizes = exchange.execute().collectPartitions().map(_.length)
+
+    assert(partitionSizes.sorted === Array(16, 16, 16, 16))
+  }
+
+  test("null-aware hash shuffle preserves retry determinism with local sorting") {
+    withSQLConf(SQLConf.SORT_BEFORE_REPARTITION.key -> "true") {
+      val input = spark.range(64).repartition(4).selectExpr("CAST(NULL AS INT) AS k")
+      val plan = input.queryExecution.executedPlan
+      val exchange = ShuffleExchangeExec(NullAwareHashPartitioning(plan.output, 4), plan)
+
+      assert(plan.execute().outputDeterministicLevel == DeterministicLevel.UNORDERED)
+      assert(exchange.shuffleDependency.rdd.outputDeterministicLevel !=
+        DeterministicLevel.INDETERMINATE)
+    }
+  }
+
+  test("null-aware hash shuffle marks unsorted repartitioning as order-sensitive") {
+    withSQLConf(SQLConf.SORT_BEFORE_REPARTITION.key -> "false") {
+      val input = spark.range(64).repartition(4).selectExpr("CAST(NULL AS INT) AS k")
+      val plan = input.queryExecution.executedPlan
+      val exchange = ShuffleExchangeExec(NullAwareHashPartitioning(plan.output, 4), plan)
+
+      assert(plan.execute().outputDeterministicLevel == DeterministicLevel.UNORDERED)
+      assert(exchange.shuffleDependency.rdd.outputDeterministicLevel ==
+        DeterministicLevel.INDETERMINATE)
+    }
+  }
+
   test("BroadcastMode.canonicalized") {
     val mode1 = IdentityBroadcastMode
     val mode2 = HashedRelationBroadcastMode(Literal(1L) :: Nil)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 50322905f29f..0e7ba599e0fb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference,
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
 import org.apache.spark.sql.catalyst.plans.{Inner, LeftAnti}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Join, JoinHint, LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.physical.CoalescedNullAwareHashPartitioning
 import org.apache.spark.sql.classic.Strategy
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
@@ -2089,55 +2090,80 @@ class AdaptiveQueryExecSuite
           |ON CAST(value AS INT) = b
         """.stripMargin)
 
-      withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
-        // Repartition with no partition num specified.
-        checkBHJ(df.repartition($"b"),
-          // The top shuffle from repartition is optimized out.
-          optimizeOutRepartition = true, probeSideLocalRead = false, probeSideCoalescedRead = true)
-
-        // Repartition with default partition num (5 in test env) specified.
-        checkBHJ(df.repartition(5, $"b"),
-          // The top shuffle from repartition is optimized out
-          // The final plan must have 5 partitions, no optimization can be made to the probe side.
-          optimizeOutRepartition = true, probeSideLocalRead = false, probeSideCoalescedRead = false)
-
-        // Repartition with non-default partition num specified.
-        checkBHJ(df.repartition(4, $"b"),
-          // The top shuffle from repartition is not optimized out
-          optimizeOutRepartition = false, probeSideLocalRead = true, probeSideCoalescedRead = true)
+      def checkRepartitionOptimization(
+          df: Dataset[Row],
+          useNullAwarePartitioning: Boolean): Unit = {
+        val optimizeDefaultRepartition = !useNullAwarePartitioning
+        withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+          // Repartition with no partition num specified.
+          checkBHJ(df.repartition($"b"),
+            optimizeOutRepartition = optimizeDefaultRepartition,
+            probeSideLocalRead = useNullAwarePartitioning,
+            probeSideCoalescedRead = !useNullAwarePartitioning)
+
+          // Repartition with default partition num (5 in test env) specified.
+          checkBHJ(df.repartition(5, $"b"),
+            optimizeOutRepartition = optimizeDefaultRepartition,
+            probeSideLocalRead = useNullAwarePartitioning,
+            probeSideCoalescedRead = false)
+
+          // Repartition with non-default partition num specified.
+          checkBHJ(df.repartition(4, $"b"),
+            optimizeOutRepartition = false,
+            probeSideLocalRead = true,
+            probeSideCoalescedRead = true)
+
+          // Repartition by col and project away the partition cols
+          checkBHJ(df.repartition($"b").select($"key"),
+            optimizeOutRepartition = false,
+            probeSideLocalRead = true,
+            probeSideCoalescedRead = true)
+        }
 
-        // Repartition by col and project away the partition cols
-        checkBHJ(df.repartition($"b").select($"key"),
-          // The top shuffle from repartition is not optimized out
-          optimizeOutRepartition = false, probeSideLocalRead = true, probeSideCoalescedRead = true)
+        // Force skew join
+        withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+          SQLConf.SKEW_JOIN_ENABLED.key -> "true",
+          SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "1",
+          SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR.key -> "0",
+          SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "10") {
+          // Repartition with no partition num specified.
+          checkSMJ(df.repartition($"b"),
+            optimizeOutRepartition = optimizeDefaultRepartition,
+            optimizeSkewJoin = useNullAwarePartitioning,
+            coalescedRead = !useNullAwarePartitioning)
+
+          // Repartition with default partition num (5 in test env) specified.
+          checkSMJ(df.repartition(5, $"b"),
+            optimizeOutRepartition = optimizeDefaultRepartition,
+            optimizeSkewJoin = useNullAwarePartitioning,
+            coalescedRead = false)
+
+          // Repartition with non-default partition num specified.
+          checkSMJ(df.repartition(4, $"b"),
+            optimizeOutRepartition = false, optimizeSkewJoin = true, coalescedRead = false)
+
+          // Repartition by col and project away the partition cols
+          checkSMJ(df.repartition($"b").select($"key"),
+            optimizeOutRepartition = false, optimizeSkewJoin = true, coalescedRead = false)
+        }
       }
 
-      // Force skew join
-      withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
-        SQLConf.SKEW_JOIN_ENABLED.key -> "true",
-        SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "1",
-        SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR.key -> "0",
-        SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "10") {
-        // Repartition with no partition num specified.
-        checkSMJ(df.repartition($"b"),
-          // The top shuffle from repartition is optimized out.
-          optimizeOutRepartition = true, optimizeSkewJoin = false, coalescedRead = true)
-
-        // Repartition with default partition num (5 in test env) specified.
-        checkSMJ(df.repartition(5, $"b"),
-          // The top shuffle from repartition is optimized out.
-          // The final plan must have 5 partitions, can't do coalesced read.
-          optimizeOutRepartition = true, optimizeSkewJoin = false, coalescedRead = false)
-
-        // Repartition with non-default partition num specified.
-        checkSMJ(df.repartition(4, $"b"),
-          // The top shuffle from repartition is not optimized out.
-          optimizeOutRepartition = false, optimizeSkewJoin = true, coalescedRead = false)
-
-        // Repartition by col and project away the partition cols
-        checkSMJ(df.repartition($"b").select($"key"),
-          // The top shuffle from repartition is not optimized out.
-          optimizeOutRepartition = false, optimizeSkewJoin = true, coalescedRead = false)
+      checkRepartitionOptimization(df, useNullAwarePartitioning = false)
+      withSQLConf(SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") {
+        // Null-aware join output partitioning is not equivalent to ordinary hash repartitioning.
+        val nullablePreservedSideDf = sql(
+          """
+            |SELECT * FROM (
+            |  SELECT * FROM testData WHERE key = 1
+            |)
+            |RIGHT OUTER JOIN (
+            |  SELECT a, b FROM testData2
+            |  UNION ALL
+            |  SELECT CAST(NULL AS INT) AS a, CAST(NULL AS INT) AS b
+            |)
+            |ON CAST(value AS INT) = b
+          """.stripMargin)
+        checkRepartitionOptimization(nullablePreservedSideDf, useNullAwarePartitioning = true)
       }
     }
   }
@@ -2604,6 +2630,39 @@ class AdaptiveQueryExecSuite
     }
   }
 
+  test("AQE preserves coalesced null-aware partitioning for outer equi-join") {
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
+      SQLConf.SHUFFLE_PARTITIONS.key -> "8",
+      SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true",
+      SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "1048576") {
+      val nullableLeft = Seq(
+        (Integer.valueOf(1), "left-1"),
+        (null.asInstanceOf[Integer], "left-null-1"),
+        (null.asInstanceOf[Integer], "left-null-2")).toDF("k", "lv")
+      val nullableRight = Seq(
+        (Integer.valueOf(1), "right-1"),
+        (null.asInstanceOf[Integer], "right-null")).toDF("k", "rv")
+      val df = nullableLeft.join(
+        nullableRight, nullableLeft("k") === nullableRight("k"), "left_outer")
+
+      checkAnswer(df, Seq(
+        Row(1, "left-1", 1, "right-1"),
+        Row(null, "left-null-1", null, null),
+        Row(null, "left-null-2", null, null)))
+
+      val coalescedNullAwareReads = collect(df.queryExecution.executedPlan) {
+        case read: AQEShuffleReadExec
+            if read.hasCoalescedPartition &&
+              read.outputPartitioning.isInstanceOf[CoalescedNullAwareHashPartitioning] =>
+          read
+      }
+      assert(coalescedNullAwareReads.nonEmpty)
+    }
+  }
+
   test("SPARK-35794: Allow custom plugin for cost evaluator") {
     CostEvaluator.instantiate(
       classOf[SimpleShuffleSortCostEvaluator].getCanonicalName, spark.sparkContext.getConf)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 73e739e261b7..2deb452c3a09 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -17,15 +17,16 @@
 
 package org.apache.spark.sql.execution.joins
 
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{Column, Row}
 import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Expression, LessThan}
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint}
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, NullAwareHashPartitioning}
 import org.apache.spark.sql.classic.DataFrame
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.exchange.EnsureRequirements
+import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{SharedSparkSession, SQLTestData}
 import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
@@ -36,6 +37,18 @@ class OuterJoinSuite extends SharedSparkSession with SQLTestData {
 
   private val EnsureRequirements = new EnsureRequirements()
 
+  private def extractJoinParts(
+      left: DataFrame,
+      right: DataFrame,
+      condition: Column): ExtractEquiJoinKeys.ReturnType = {
+    val analyzedJoin = left.join(right, condition, "inner")
+      .queryExecution.analyzed
+      .collectFirst { case join: Join => join }
+      .getOrElse(fail("Failed to build analyzed equi-join"))
+    ExtractEquiJoinKeys.unapply(analyzedJoin)
+      .getOrElse(fail("Failed to extract equi-join keys"))
+  }
+
   private lazy val left = spark.createDataFrame(
     sparkContext.parallelize(Seq(
       Row(1, 2.0),
@@ -345,4 +358,305 @@ class OuterJoinSuite extends SharedSparkSession with SQLTestData {
     val df2 = join("SHUFFLE_MERGE(t1)")
     checkAnswer(df1, identity, df2.collect().toSeq)
   }
+
+  test("ordinary outer equi-join spreads NULL keys in shuffle partitioning") {
+    val nullableLeft = Seq(
+      (Integer.valueOf(1), "left-1"),
+      (null.asInstanceOf[Integer], "left-null-1"),
+      (null.asInstanceOf[Integer], "left-null-2")).toDF("k", "lv")
+    val nullableRight = Seq(
+      (Integer.valueOf(1), "right-1"),
+      (null.asInstanceOf[Integer], "right-null")).toDF("k", "rv")
+    val (_, leftKeys, rightKeys, boundCondition, _, _, _, _) =
+      extractJoinParts(nullableLeft, nullableRight, nullableLeft("k") === nullableRight("k"))
+    withSQLConf(
+        SQLConf.SHUFFLE_PARTITIONS.key -> "4",
+        SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") {
+      val plan = EnsureRequirements.apply(
+        SortMergeJoinExec(leftKeys, rightKeys, LeftOuter, boundCondition,
+          nullableLeft.queryExecution.sparkPlan, nullableRight.queryExecution.sparkPlan))
+      val partitionings = plan.collect {
+        case exchange: ShuffleExchangeExec => exchange.outputPartitioning
+      }
+      assert(partitionings.size == 2)
+      assert(partitionings.forall(_.isInstanceOf[NullAwareHashPartitioning]))
+
+      checkAnswer2(nullableLeft, nullableRight, (left: SparkPlan, right: SparkPlan) =>
+        EnsureRequirements.apply(
+          SortMergeJoinExec(leftKeys, rightKeys, LeftOuter, boundCondition, left, right)),
+        Seq(
+          Row(1, "left-1", 1, "right-1"),
+          Row(null, "left-null-1", null, null),
+          Row(null, "left-null-2", null, null)),
+        sortAnswers = true)
+    }
+  }
+
+  test("ordinary outer equi-join keeps hash partitioning when null-aware shuffle is disabled") {
+    val nullableLeft = Seq(
+      (Integer.valueOf(1), "left-1"),
+      (null.asInstanceOf[Integer], "left-null")).toDF("k", "lv")
+    val nullableRight = Seq(
+      (Integer.valueOf(1), "right-1"),
+      (null.asInstanceOf[Integer], "right-null")).toDF("k", "rv")
+    val (_, leftKeys, rightKeys, boundCondition, _, _, _, _) =
+      extractJoinParts(nullableLeft, nullableRight, nullableLeft("k") === nullableRight("k"))
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "4") {
+      val plan = EnsureRequirements.apply(
+        SortMergeJoinExec(leftKeys, rightKeys, LeftOuter, boundCondition,
+          nullableLeft.queryExecution.sparkPlan, nullableRight.queryExecution.sparkPlan))
+      val partitionings = plan.collect {
+        case exchange: ShuffleExchangeExec => exchange.outputPartitioning
+      }
+      assert(partitionings.size == 2)
+      assert(partitionings.forall(_.isInstanceOf[HashPartitioning]))
+    }
+  }
+
+  test("ordinary outer equi-join keeps hash partitioning for non-nullable join keys") {
+    val nonNullableLeft = spark.range(3).toDF("k")
+    val nonNullableRight = spark.range(3).toDF("k")
+    val (_, leftKeys, rightKeys, boundCondition, _, _, _, _) =
+      extractJoinParts(
+        nonNullableLeft,
+        nonNullableRight,
+        nonNullableLeft("k") === nonNullableRight("k"))
+    withSQLConf(
+        SQLConf.SHUFFLE_PARTITIONS.key -> "4",
+        SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") {
+      val plan = EnsureRequirements.apply(
+        SortMergeJoinExec(leftKeys, rightKeys, LeftOuter, boundCondition,
+          nonNullableLeft.queryExecution.sparkPlan, nonNullableRight.queryExecution.sparkPlan))
+      val partitionings = plan.collect {
+        case exchange: ShuffleExchangeExec => exchange.outputPartitioning
+      }
+      assert(partitionings.size == 2)
+      assert(partitionings.forall(_.isInstanceOf[HashPartitioning]))
+    }
+  }
+
+  test("ordinary right outer equi-join spreads NULL keys in shuffle partitioning") {
+    val nullableLeft = Seq(
+      (Integer.valueOf(1), "left-1"),
+      (null.asInstanceOf[Integer], "left-null")).toDF("k", "lv")
+    val nullableRight = Seq(
+      (Integer.valueOf(1), "right-1"),
+      (null.asInstanceOf[Integer], "right-null-1"),
+      (null.asInstanceOf[Integer], "right-null-2")).toDF("k", "rv")
+    val (_, leftKeys, rightKeys, boundCondition, _, _, _, _) =
+      extractJoinParts(nullableLeft, nullableRight, nullableLeft("k") === nullableRight("k"))
+    withSQLConf(
+        SQLConf.SHUFFLE_PARTITIONS.key -> "4",
+        SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") {
+      val plan = EnsureRequirements.apply(
+        SortMergeJoinExec(leftKeys, rightKeys, RightOuter, boundCondition,
+          nullableLeft.queryExecution.sparkPlan, nullableRight.queryExecution.sparkPlan))
+      val partitionings = plan.collect {
+        case exchange: ShuffleExchangeExec => exchange.outputPartitioning
+      }
+      assert(partitionings.size == 2)
+      assert(partitionings.forall(_.isInstanceOf[NullAwareHashPartitioning]))
+
+      checkAnswer2(nullableLeft, nullableRight, (left: SparkPlan, right: SparkPlan) =>
+        EnsureRequirements.apply(
+          SortMergeJoinExec(leftKeys, rightKeys, RightOuter, boundCondition, left, right)),
+        Seq(
+          Row(1, "left-1", 1, "right-1"),
+          Row(null, null, null, "right-null-1"),
+          Row(null, null, null, "right-null-2")),
+        sortAnswers = true)
+    }
+  }
+
+  test("ordinary full outer equi-join keeps NULL keys unmatched while spreading them") {
+    val nullableLeft = Seq(
+      (Integer.valueOf(1), "left-1"),
+      (null.asInstanceOf[Integer], "left-null-1"),
+      (null.asInstanceOf[Integer], "left-null-2")).toDF("k", "lv")
+    val nullableRight = Seq(
+      (Integer.valueOf(1), "right-1"),
+      (null.asInstanceOf[Integer], "right-null-1"),
+      (null.asInstanceOf[Integer], "right-null-2")).toDF("k", "rv")
+    val (_, leftKeys, rightKeys, boundCondition, _, _, _, _) =
+      extractJoinParts(nullableLeft, nullableRight, nullableLeft("k") === nullableRight("k"))
+    withSQLConf(
+        SQLConf.SHUFFLE_PARTITIONS.key -> "4",
+        SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") {
+      val plan = EnsureRequirements.apply(
+        SortMergeJoinExec(leftKeys, rightKeys, FullOuter, boundCondition,
+          nullableLeft.queryExecution.sparkPlan, nullableRight.queryExecution.sparkPlan))
+      val partitionings = plan.collect {
+        case exchange: ShuffleExchangeExec => exchange.outputPartitioning
+      }
+      assert(partitionings.size == 2)
+      assert(partitionings.forall(_.isInstanceOf[NullAwareHashPartitioning]))
+
+      checkAnswer2(nullableLeft, nullableRight, (left: SparkPlan, right: SparkPlan) =>
+        EnsureRequirements.apply(
+          SortMergeJoinExec(leftKeys, rightKeys, FullOuter, boundCondition, left, right)),
+        Seq(
+          Row(1, "left-1", 1, "right-1"),
+          Row(null, "left-null-1", null, null),
+          Row(null, "left-null-2", null, null),
+          Row(null, null, null, "right-null-1"),
+          Row(null, null, null, "right-null-2")),
+        sortAnswers = true)
+    }
+  }
+
+  test("ordinary outer equi-join preserves null-aware shuffle beside existing hash partitioning") {
+    val nullableLeft = Seq(
+      (Integer.valueOf(1), "left-1"),
+      (null.asInstanceOf[Integer], "left-null")).toDF("k", "lv")
+    val nullableRight = Seq(
+      (Integer.valueOf(1), "right-1"),
+      (null.asInstanceOf[Integer], "right-null")).toDF("k", "rv")
+    val (_, leftKeys, rightKeys, boundCondition, _, _, _, _) =
+      extractJoinParts(nullableLeft, nullableRight, nullableLeft("k") === nullableRight("k"))
+    withSQLConf(
+        SQLConf.SHUFFLE_PARTITIONS.key -> "4",
+        SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") {
+      val existingLeftShuffle = ShuffleExchangeExec(
+        HashPartitioning(leftKeys, 4),
+        nullableLeft.queryExecution.sparkPlan)
+      val plan = EnsureRequirements.apply(
+        SortMergeJoinExec(leftKeys, rightKeys, LeftOuter, boundCondition,
+          existingLeftShuffle, nullableRight.queryExecution.sparkPlan))
+      val partitionings = plan.collect {
+        case exchange: ShuffleExchangeExec => exchange.outputPartitioning
+      }
+
+      assert(partitionings.size == 2)
+      assert(partitionings.count(_.isInstanceOf[HashPartitioning]) == 1)
+      assert(partitionings.count(_.isInstanceOf[NullAwareHashPartitioning]) == 1)
+    }
+  }
+
+  test("mixed ordinary and null-safe outer equi-join can use null-aware shuffle partitioning") {
+    val nullableLeft = Seq(
+      (Integer.valueOf(1), null.asInstanceOf[Integer], "left-match"),
+      (Integer.valueOf(2), null.asInstanceOf[Integer], "left-no-match"))
+      .toDF("k1", "k2", "lv")
+    val nullableRight = Seq(
+      (Integer.valueOf(1), null.asInstanceOf[Integer], "right-match"),
+      (Integer.valueOf(2), Integer.valueOf(3), "right-no-match"))
+      .toDF("k1", "k2", "rv")
+    val (_, leftKeys, rightKeys, boundCondition, _, _, _, _) =
+      extractJoinParts(
+        nullableLeft,
+        nullableRight,
+        nullableLeft("k1") === nullableRight("k1") &&
+          nullableLeft("k2").eqNullSafe(nullableRight("k2")))
+    withSQLConf(
+        SQLConf.SHUFFLE_PARTITIONS.key -> "4",
+        SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") {
+      val plan = EnsureRequirements.apply(
+        SortMergeJoinExec(leftKeys, rightKeys, LeftOuter, boundCondition,
+          nullableLeft.queryExecution.sparkPlan, nullableRight.queryExecution.sparkPlan))
+      val partitionings = plan.collect {
+        case exchange: ShuffleExchangeExec => exchange.outputPartitioning
+      }
+      assert(partitionings.size == 2)
+      assert(partitionings.forall(_.isInstanceOf[NullAwareHashPartitioning]))
+
+      checkAnswer2(nullableLeft, nullableRight, (left: SparkPlan, right: SparkPlan) =>
+        EnsureRequirements.apply(
+          SortMergeJoinExec(leftKeys, rightKeys, LeftOuter, boundCondition, left, right)),
+        Seq(
+          Row(1, null, "left-match", 1, null, "right-match"),
+          Row(2, null, "left-no-match", null, null, null)),
+        sortAnswers = true)
+    }
+  }
+
+  test("null-safe outer equi-join keeps hash partitioning for non-null shuffle keys") {
+    val nullableLeft = Seq(
+      (Integer.valueOf(1), "left-1"),
+      (null.asInstanceOf[Integer], "left-null"))
+      .toDF("k", "lv")
+    val nullableRight = Seq(
+      (Integer.valueOf(1), "right-1"),
+      (null.asInstanceOf[Integer], "right-null")).toDF("k", "rv")
+    val (_, leftKeys, rightKeys, boundCondition, _, _, _, _) =
+      extractJoinParts(
+        nullableLeft,
+        nullableRight,
+        nullableLeft("k").eqNullSafe(nullableRight("k")))
+    withSQLConf(
+        SQLConf.SHUFFLE_PARTITIONS.key -> "4",
+        SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") {
+      val plan = EnsureRequirements.apply(
+        SortMergeJoinExec(leftKeys, rightKeys, LeftOuter, boundCondition,
+          nullableLeft.queryExecution.sparkPlan, nullableRight.queryExecution.sparkPlan))
+      val partitionings = plan.collect {
+        case exchange: ShuffleExchangeExec => exchange.outputPartitioning
+      }
+      assert(partitionings.size == 2)
+      assert(partitionings.forall(_.isInstanceOf[HashPartitioning]))
+    }
+  }
+
+  test("ordinary outer equi-join spreads NULL keys for shuffled hash join") {
+    val nullableLeft = Seq(
+      (Integer.valueOf(1), "left-1"),
+      (null.asInstanceOf[Integer], "left-null-1"),
+      (null.asInstanceOf[Integer], "left-null-2")).toDF("k", "lv")
+    val nullableRight = Seq(
+      (Integer.valueOf(1), "right-1"),
+      (null.asInstanceOf[Integer], "right-null")).toDF("k", "rv")
+    val (_, leftKeys, rightKeys, boundCondition, _, _, _, _) =
+      extractJoinParts(nullableLeft, nullableRight, nullableLeft("k") === nullableRight("k"))
+    withSQLConf(
+        SQLConf.SHUFFLE_PARTITIONS.key -> "4",
+        SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") {
+      val plan = EnsureRequirements.apply(
+        ShuffledHashJoinExec(leftKeys, rightKeys, LeftOuter, BuildRight, boundCondition,
+          nullableLeft.queryExecution.sparkPlan, nullableRight.queryExecution.sparkPlan))
+      val partitionings = plan.collect {
+        case exchange: ShuffleExchangeExec => exchange.outputPartitioning
+      }
+      assert(partitionings.size == 2)
+      assert(partitionings.forall(_.isInstanceOf[NullAwareHashPartitioning]))
+
+      checkAnswer2(nullableLeft, nullableRight, (left: SparkPlan, right: SparkPlan) =>
+        EnsureRequirements.apply(
+          ShuffledHashJoinExec(
+            leftKeys, rightKeys, LeftOuter, BuildRight, boundCondition, left, right)),
+        Seq(
+          Row(1, "left-1", 1, "right-1"),
+          Row(null, "left-null-1", null, null),
+          Row(null, "left-null-2", null, null)),
+        sortAnswers = true)
+    }
+  }
+
+  test("NullType null-safe outer equi-join remains result-safe with null-aware shuffle") {
+    val nullTypeLeft = spark.range(2).selectExpr("NULL AS k", "id AS lv")
+    val nullTypeRight = spark.range(1).selectExpr("NULL AS k", "id AS rv")
+    val (_, leftKeys, rightKeys, boundCondition, _, _, _, _) =
+      extractJoinParts(
+        nullTypeLeft,
+        nullTypeRight,
+        nullTypeLeft("k").eqNullSafe(nullTypeRight("k")))
+    withSQLConf(
+        SQLConf.SHUFFLE_PARTITIONS.key -> "4",
+        SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") {
+      val plan = EnsureRequirements.apply(
+        SortMergeJoinExec(leftKeys, rightKeys, LeftOuter, boundCondition,
+          nullTypeLeft.queryExecution.sparkPlan, nullTypeRight.queryExecution.sparkPlan))
+      val partitionings = plan.collect {
+        case exchange: ShuffleExchangeExec => exchange.outputPartitioning
+      }
+      assert(partitionings.size == 2)
+      assert(partitionings.forall(_.isInstanceOf[NullAwareHashPartitioning]))
+
+      checkAnswer2(nullTypeLeft, nullTypeRight, (left: SparkPlan, right: SparkPlan) =>
+        EnsureRequirements.apply(
+          SortMergeJoinExec(leftKeys, rightKeys, LeftOuter, boundCondition, left, right)),
+        Seq(
+          Row(null, 0L, null, null),
+          Row(null, 1L, null, null)),
+        sortAnswers = true)
+    }
+  }
 }

