This is an automated email from the ASF dual-hosted git repository.

richox pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/auron.git


The following commit(s) were added to refs/heads/master by this push:
     new b9168f43 [AURON #1792][FOLLOWUP] Broadcast isNullAwareAntiJoin flag 
(#1866)
b9168f43 is described below

commit b9168f43381c299917ef5090d97cfb236ee6559b
Author: cxzl25 <[email protected]>
AuthorDate: Mon Jan 12 11:25:57 2026 +0800

    [AURON #1792][FOLLOWUP] Broadcast isNullAwareAntiJoin flag (#1866)
    
    # Which issue does this PR close?
    
    Closes #1792
    
    # Rationale for this change
    
    # What changes are included in this PR?
    
    # Are there any user-facing changes?
    
    # How was this patch tested?
    Add UT (unit tests)
---
 native-engine/auron-serde/proto/auron.proto        |  1 +
 native-engine/auron-serde/src/from_proto.rs        |  3 +
 .../src/broadcast_join_exec.rs                     |  5 ++
 .../src/joins/bhj/semi_join.rs                     |  6 +-
 .../datafusion-ext-plans/src/joins/mod.rs          |  1 +
 .../datafusion-ext-plans/src/joins/test.rs         | 20 ++----
 .../src/sort_merge_join_exec.rs                    |  1 +
 .../org/apache/spark/sql/auron/ShimsImpl.scala     |  6 +-
 .../joins/auron/plan/NativeBroadcastJoinExec.scala |  6 +-
 .../scala/org/apache/auron/AuronQuerySuite.scala   | 81 ++++++++++++++++++++++
 .../apache/spark/sql/auron/AuronConverters.scala   | 17 +++--
 .../scala/org/apache/spark/sql/auron/Shims.scala   |  3 +-
 .../auron/plan/NativeBroadcastJoinBase.scala       |  4 +-
 13 files changed, 128 insertions(+), 26 deletions(-)

diff --git a/native-engine/auron-serde/proto/auron.proto 
b/native-engine/auron-serde/proto/auron.proto
index 29e9f113..788be352 100644
--- a/native-engine/auron-serde/proto/auron.proto
+++ b/native-engine/auron-serde/proto/auron.proto
@@ -468,6 +468,7 @@ message BroadcastJoinExecNode {
   JoinType join_type = 5;
   JoinSide broadcast_side = 6;
   string cached_build_hash_map_id = 7;
+  bool is_null_aware_anti_join = 8;
 }
 
 message RenameColumnsExecNode {
diff --git a/native-engine/auron-serde/src/from_proto.rs 
b/native-engine/auron-serde/src/from_proto.rs
index 09c700db..9237a6c2 100644
--- a/native-engine/auron-serde/src/from_proto.rs
+++ b/native-engine/auron-serde/src/from_proto.rs
@@ -219,6 +219,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for 
&protobuf::PhysicalPlanNode {
                         .map_err(|_| proto_error("invalid BuildSide"))?,
                     false,
                     None,
+                    false,
                 )?))
             }
             PhysicalPlanType::SortMergeJoin(sort_merge_join) => {
@@ -354,6 +355,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for 
&protobuf::PhysicalPlanNode {
                     .expect("invalid BroadcastSide");
 
                 let cached_build_hash_map_id = 
broadcast_join.cached_build_hash_map_id.clone();
+                let is_null_aware_anti_join = 
broadcast_join.is_null_aware_anti_join;
 
                 Ok(Arc::new(BroadcastJoinExec::try_new(
                     schema,
@@ -368,6 +370,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for 
&protobuf::PhysicalPlanNode {
                         .map_err(|_| proto_error("invalid BroadcastSide"))?,
                     true,
                     Some(cached_build_hash_map_id),
+                    is_null_aware_anti_join,
                 )?))
             }
             PhysicalPlanType::Union(union) => {
diff --git a/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs 
b/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs
index 276f5e09..fef3397b 100644
--- a/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs
+++ b/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs
@@ -88,6 +88,7 @@ pub struct BroadcastJoinExec {
     schema: SchemaRef,
     is_built: bool, // true for BroadcastHashJoin, false for ShuffledHashJoin
     cached_build_hash_map_id: Option<String>,
+    is_null_aware_anti_join: bool,
     metrics: ExecutionPlanMetricsSet,
     props: OnceCell<PlanProperties>,
 }
@@ -102,6 +103,7 @@ impl BroadcastJoinExec {
         broadcast_side: JoinSide,
         is_built: bool,
         cached_build_hash_map_id: Option<String>,
+        is_null_aware_anti_join: bool,
     ) -> Result<Self> {
         Ok(Self {
             left,
@@ -112,6 +114,7 @@ impl BroadcastJoinExec {
             schema,
             is_built,
             cached_build_hash_map_id,
+            is_null_aware_anti_join,
             metrics: ExecutionPlanMetricsSet::new(),
             props: OnceCell::new(),
         })
@@ -176,6 +179,7 @@ impl BroadcastJoinExec {
             sort_options: vec![SortOptions::default(); self.on.len()],
             projection,
             key_data_types,
+            is_null_aware_anti_join: self.is_null_aware_anti_join,
         })
     }
 
@@ -279,6 +283,7 @@ impl ExecutionPlan for BroadcastJoinExec {
             self.broadcast_side,
             self.is_built,
             None,
+            self.is_null_aware_anti_join,
         )?))
     }
 
diff --git a/native-engine/datafusion-ext-plans/src/joins/bhj/semi_join.rs 
b/native-engine/datafusion-ext-plans/src/joins/bhj/semi_join.rs
index 41ebcf6f..1018b72a 100644
--- a/native-engine/datafusion-ext-plans/src/joins/bhj/semi_join.rs
+++ b/native-engine/datafusion-ext-plans/src/joins/bhj/semi_join.rs
@@ -193,7 +193,11 @@ impl<const P: JoinerParams> Joiner for SemiJoiner<P> {
                 .as_ref()
                 .map(|nb| nb.is_valid(row_idx))
                 .unwrap_or(true);
-            if P.mode == Anti && P.probe_is_join_side && !key_is_valid {
+            if P.mode == Anti
+                && P.probe_is_join_side
+                && !key_is_valid
+                && self.join_params.is_null_aware_anti_join
+            {
                 probed_joined.set(row_idx, true);
                 continue;
             }
diff --git a/native-engine/datafusion-ext-plans/src/joins/mod.rs 
b/native-engine/datafusion-ext-plans/src/joins/mod.rs
index 5f8ae997..6ccc4086 100644
--- a/native-engine/datafusion-ext-plans/src/joins/mod.rs
+++ b/native-engine/datafusion-ext-plans/src/joins/mod.rs
@@ -46,6 +46,7 @@ pub struct JoinParams {
     pub sort_options: Vec<SortOptions>,
     pub projection: JoinProjection,
     pub batch_size: usize,
+    pub is_null_aware_anti_join: bool,
 }
 
 #[derive(Debug, Clone)]
diff --git a/native-engine/datafusion-ext-plans/src/joins/test.rs 
b/native-engine/datafusion-ext-plans/src/joins/test.rs
index 671ecd73..2c50cabb 100644
--- a/native-engine/datafusion-ext-plans/src/joins/test.rs
+++ b/native-engine/datafusion-ext-plans/src/joins/test.rs
@@ -219,6 +219,7 @@ mod tests {
                     JoinSide::Right,
                     true,
                     None,
+                    false,
                 )?)
             }
             BHJRightProbed => {
@@ -235,6 +236,7 @@ mod tests {
                     JoinSide::Left,
                     true,
                     None,
+                    false,
                 )?)
             }
             SHJLeftProbed => Arc::new(BroadcastJoinExec::try_new(
@@ -246,6 +248,7 @@ mod tests {
                 JoinSide::Right,
                 false,
                 None,
+                false,
             )?),
             SHJRightProbed => Arc::new(BroadcastJoinExec::try_new(
                 schema,
@@ -256,6 +259,7 @@ mod tests {
                 JoinSide::Left,
                 false,
                 None,
+                false,
             )?),
         };
         let columns = columns(&join.schema());
@@ -617,21 +621,7 @@ mod tests {
             Arc::new(Column::new_with_schema("b1", &right.schema())?),
         )];
 
-        for test_type in [BHJLeftProbed, SHJLeftProbed] {
-            let (_, batches) =
-                join_collect(test_type, left.clone(), right.clone(), 
on.clone(), LeftAnti).await?;
-            let expected = vec![
-                "+----+----+----+",
-                "| a1 | b1 | c1 |",
-                "+----+----+----+",
-                "|    | 6  | 9  |",
-                "| 5  | 8  | 11 |",
-                "+----+----+----+",
-            ];
-            assert_batches_sorted_eq!(expected, &batches);
-        }
-
-        for test_type in [SMJ, BHJRightProbed, SHJRightProbed] {
+        for test_type in ALL_TEST_TYPE {
             let (_, batches) =
                 join_collect(test_type, left.clone(), right.clone(), 
on.clone(), LeftAnti).await?;
             let expected = vec![
diff --git a/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs 
b/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs
index 78eda5b6..91496c49 100644
--- a/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs
+++ b/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs
@@ -128,6 +128,7 @@ impl SortMergeJoinExec {
             sort_options: self.sort_options.clone(),
             projection,
             batch_size: batch_size(),
+            is_null_aware_anti_join: false,
         })
     }
 
diff --git 
a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
 
b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
index 3acbbed9..9cecb869 100644
--- 
a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
+++ 
b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
@@ -229,7 +229,8 @@ class ShimsImpl extends Shims with Logging {
       leftKeys: Seq[Expression],
       rightKeys: Seq[Expression],
       joinType: JoinType,
-      broadcastSide: BroadcastSide): NativeBroadcastJoinBase =
+      broadcastSide: BroadcastSide,
+      isNullAwareAntiJoin: Boolean): NativeBroadcastJoinBase =
     NativeBroadcastJoinExec(
       left,
       right,
@@ -237,7 +238,8 @@ class ShimsImpl extends Shims with Logging {
       leftKeys,
       rightKeys,
       joinType,
-      broadcastSide)
+      broadcastSide,
+      isNullAwareAntiJoin)
 
   override def createNativeSortMergeJoinExec(
       left: SparkPlan,
diff --git 
a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeBroadcastJoinExec.scala
 
b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeBroadcastJoinExec.scala
index d0c2cea8..9ac6e893 100644
--- 
a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeBroadcastJoinExec.scala
+++ 
b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeBroadcastJoinExec.scala
@@ -35,7 +35,8 @@ case class NativeBroadcastJoinExec(
     override val leftKeys: Seq[Expression],
     override val rightKeys: Seq[Expression],
     override val joinType: JoinType,
-    broadcastSide: BroadcastSide)
+    broadcastSide: BroadcastSide,
+    isNullAwareAntiJoin: Boolean)
     extends NativeBroadcastJoinBase(
       left,
       right,
@@ -43,7 +44,8 @@ case class NativeBroadcastJoinExec(
       leftKeys,
       rightKeys,
       joinType,
-      broadcastSide)
+      broadcastSide,
+      isNullAwareAntiJoin)
     with HashJoin {
 
   override val condition: Option[Expression] = None
diff --git 
a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronQuerySuite.scala
 
b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronQuerySuite.scala
index 3a2cc9cf..8fe2a3e3 100644
--- 
a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronQuerySuite.scala
+++ 
b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronQuerySuite.scala
@@ -581,4 +581,85 @@ class AuronQuerySuite extends AuronQueryTest with 
BaseAuronSQLSuite with AuronSQ
       }
     }
   }
+
+  test("standard LEFT ANTI JOIN includes NULL keys") {
+    // This test verifies that standard LEFT ANTI JOIN correctly includes NULL 
keys
+    // NULL keys should be in the result because NULL never matches anything
+    withTable("left_table", "right_table") {
+      sql("""
+            |CREATE TABLE left_table using parquet AS
+            |SELECT * FROM VALUES
+            |  (1, 2.0),
+            |  (1, 2.0),
+            |  (2, 1.0),
+            |  (2, 1.0),
+            |  (3, 3.0),
+            |  (null, null),
+            |  (null, 5.0),
+            |  (6, null)
+            |AS t(a, b)
+            |""".stripMargin)
+
+      sql("""
+            |CREATE TABLE right_table using parquet AS
+            |SELECT * FROM VALUES
+            |  (2, 3.0),
+            |  (2, 3.0),
+            |  (3, 2.0),
+            |  (4, 1.0),
+            |  (null, null),
+            |  (null, 5.0),
+            |  (6, null)
+            |AS t(c, d)
+            |""".stripMargin)
+
+      // Standard LEFT ANTI JOIN should include rows with NULL keys
+      // Expected: (1, 2.0), (1, 2.0), (null, null), (null, 5.0)
+      checkSparkAnswer(
+        "SELECT * FROM left_table LEFT ANTI JOIN right_table ON left_table.a = 
right_table.c")
+    }
+  }
+
+  test("left join with NOT IN subquery should filter NULL values") {
+    // This test verifies the fix for the NULL handling issue in Anti join.
+    withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") {
+      val query =
+        """
+          |WITH t2 AS (
+          |  -- Large table: 100000 rows (0..99999)
+          |  SELECT id AS loan_req_no
+          |  FROM range(0, 100000)
+          |),
+          |t1 AS (
+          |  -- Small table: 10 rows that can match t2
+          |  SELECT * FROM VALUES
+          |    (1, 'A'),
+          |    (2, 'B'),
+          |    (3, 'C'),
+          |    (4, 'D'),
+          |    (5, 'E'),
+          |    (6, 'F'),
+          |    (7, 'G'),
+          |    (8, 'H'),
+          |    (9, 'I'),
+          |    (10,'J')
+          |  AS t1(loan_req_no, partner_code)
+          |),
+          |blk AS (
+          |  SELECT * FROM VALUES
+          |    ('B'),
+          |    ('Z')
+          |  AS blk(code)
+          |)
+          |SELECT
+          |  COUNT(*) AS cnt
+          |FROM t2
+          |LEFT JOIN t1
+          |  ON t1.loan_req_no = t2.loan_req_no
+          |WHERE t1.partner_code NOT IN (SELECT code FROM blk)
+          |""".stripMargin
+
+      checkSparkAnswer(query)
+    }
+  }
 }
diff --git 
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala
 
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala
index 413ad7be..491f85da 100644
--- 
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala
+++ 
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala
@@ -664,16 +664,23 @@ object AuronConverters extends Logging {
     }
   }
 
+  @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5")
+  def isNullAwareAntiJoin(exec: BroadcastHashJoinExec): Boolean = 
exec.isNullAwareAntiJoin
+
+  @sparkver("3.0")
+  def isNullAwareAntiJoin(exec: BroadcastHashJoinExec): Boolean = false
+
   def convertBroadcastHashJoinExec(exec: BroadcastHashJoinExec): SparkPlan = {
     try {
-      val (leftKeys, rightKeys, joinType, buildSide, condition, left, right) = 
(
+      val (leftKeys, rightKeys, joinType, buildSide, condition, left, right, 
naaj) = (
         exec.leftKeys,
         exec.rightKeys,
         exec.joinType,
         exec.buildSide,
         exec.condition,
         exec.left,
-        exec.right)
+        exec.right,
+        isNullAwareAntiJoin(exec))
       logDebugPlanConversion(
         exec,
         Seq(
@@ -702,7 +709,8 @@ object AuronConverters extends Logging {
         buildSide match {
           case BuildLeft => BroadcastLeft
           case BuildRight => BroadcastRight
-        })
+        },
+        naaj)
 
     } catch {
       case e @ (_: NotImplementedError | _: Exception) =>
@@ -744,7 +752,8 @@ object AuronConverters extends Logging {
         buildSide match {
           case BuildLeft => BroadcastLeft
           case BuildRight => BroadcastRight
-        })
+        },
+        isNullAwareAntiJoin = false)
 
     } catch {
       case e @ (_: NotImplementedError | _: Exception) =>
diff --git 
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala 
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala
index a192e198..a0dd37ae 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala
@@ -86,7 +86,8 @@ abstract class Shims {
       leftKeys: Seq[Expression],
       rightKeys: Seq[Expression],
       joinType: JoinType,
-      broadcastSide: BroadcastSide): NativeBroadcastJoinBase
+      broadcastSide: BroadcastSide,
+      isNullAwareAntiJoin: Boolean): NativeBroadcastJoinBase
 
   def createNativeSortMergeJoinExec(
       left: SparkPlan,
diff --git 
a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastJoinBase.scala
 
b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastJoinBase.scala
index dabeba3f..3281947c 100644
--- 
a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastJoinBase.scala
+++ 
b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastJoinBase.scala
@@ -52,7 +52,8 @@ abstract class NativeBroadcastJoinBase(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
     joinType: JoinType,
-    broadcastSide: BroadcastSide)
+    broadcastSide: BroadcastSide,
+    isNullAwareAntiJoin: Boolean)
     extends BinaryExecNode
     with NativeSupports {
 
@@ -174,6 +175,7 @@ abstract class NativeBroadcastJoinBase(
           .setJoinType(nativeJoinType)
           .setBroadcastSide(nativeBroadcastSide)
           .setCachedBuildHashMapId(cachedBuildHashMapId)
+          .setIsNullAwareAntiJoin(isNullAwareAntiJoin)
           .addAllOn(nativeJoinOn.asJava)
 
         
pb.PhysicalPlanNode.newBuilder().setBroadcastJoin(broadcastJoinExec).build()

Reply via email to