This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 11fc52d318 Use dedicated NullEquality enum instead of null_equals_null boolean (#16419)
11fc52d318 is described below

commit 11fc52d318ebdce2c683788563b39d76df04de70
Author: Tobias Schwarzinger <tobias.schwarzin...@tuwien.ac.at>
AuthorDate: Wed Jun 18 00:52:27 2025 +0200

    Use dedicated NullEquality enum instead of null_equals_null boolean (#16419)
    
    * Use dedicated NullEquality enum instead of null_equals_null boolean
    
    * Fix wrong operator mapping in hash_join
    
    * Add an example to the documentation
---
 datafusion/common/src/lib.rs                       |   2 +
 datafusion/common/src/null_equality.rs             |  46 +++
 datafusion/core/src/physical_planner.rs            |  10 +-
 datafusion/core/tests/execution/infinite_cancel.rs |  10 +-
 datafusion/core/tests/fuzz_cases/join_fuzz.rs      |   6 +-
 .../tests/physical_optimizer/join_selection.rs     |  20 +-
 .../physical_optimizer/projection_pushdown.rs      |   8 +-
 .../replace_with_order_preserving_variants.rs      |   4 +-
 .../core/tests/physical_optimizer/test_utils.rs    |   6 +-
 datafusion/expr/src/logical_plan/builder.rs        |  43 ++-
 datafusion/expr/src/logical_plan/plan.rs           |  50 +--
 datafusion/expr/src/logical_plan/tree_node.rs      |   8 +-
 datafusion/optimizer/src/eliminate_cross_join.rs   |  42 +--
 datafusion/optimizer/src/eliminate_outer_join.rs   |   2 +-
 .../optimizer/src/extract_equijoin_predicate.rs    |   6 +-
 datafusion/optimizer/src/filter_null_join_keys.rs  |   5 +-
 .../physical-optimizer/src/enforce_distribution.rs |  16 +-
 .../physical-optimizer/src/join_selection.rs       |   8 +-
 datafusion/physical-plan/src/joins/hash_join.rs    | 388 +++++++++++++++------
 .../physical-plan/src/joins/nested_loop_join.rs    |   2 +-
 .../physical-plan/src/joins/sort_merge_join.rs     | 116 +++---
 .../physical-plan/src/joins/symmetric_hash_join.rs |  56 +--
 datafusion/physical-plan/src/joins/test_utils.rs   |  10 +-
 .../proto-common/proto/datafusion_common.proto     |   5 +
 datafusion/proto-common/src/generated/pbjson.rs    |  71 ++++
 datafusion/proto-common/src/generated/prost.rs     |  26 ++
 datafusion/proto/proto/datafusion.proto            |   6 +-
 .../proto/src/generated/datafusion_proto_common.rs |  26 ++
 datafusion/proto/src/generated/pbjson.rs           |  84 ++---
 datafusion/proto/src/generated/prost.rs            |  12 +-
 datafusion/proto/src/logical_plan/from_proto.rs    |  13 +-
 datafusion/proto/src/logical_plan/mod.rs           |   6 +-
 datafusion/proto/src/logical_plan/to_proto.rs      |  11 +-
 datafusion/proto/src/physical_plan/mod.rs          |  24 +-
 .../proto/tests/cases/roundtrip_physical_plan.rs   |   6 +-
 datafusion/sqllogictest/test_files/joins.slt       |   2 +-
 .../src/logical_plan/consumer/rel/join_rel.rs      |  19 +-
 .../src/logical_plan/producer/rel/join.rs          |  11 +-
 38 files changed, 823 insertions(+), 363 deletions(-)
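
Viewed from a call site, the change below is a mechanical substitution: every join construction that used to take a `null_equals_null: bool` argument now takes a `NullEquality` value. A minimal sketch of the correspondence (not part of the patch; the helper name is hypothetical):

    use datafusion_common::NullEquality;

    /// Hypothetical helper: translate the new enum back to the old boolean flag.
    fn null_equals_null(null_equality: NullEquality) -> bool {
        match null_equality {
            NullEquality::NullEqualsNothing => false, // previously `null_equals_null = false`
            NullEquality::NullEqualsNull => true,     // previously `null_equals_null = true`
        }
    }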

diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs
index 7b2c86d397..d89e08c7d4 100644
--- a/datafusion/common/src/lib.rs
+++ b/datafusion/common/src/lib.rs
@@ -46,6 +46,7 @@ pub mod file_options;
 pub mod format;
 pub mod hash_utils;
 pub mod instant;
+mod null_equality;
 pub mod parsers;
 pub mod pruning;
 pub mod rounding;
@@ -79,6 +80,7 @@ pub use functional_dependencies::{
 };
 use hashbrown::hash_map::DefaultHashBuilder;
 pub use join_type::{JoinConstraint, JoinSide, JoinType};
+pub use null_equality::NullEquality;
 pub use param_value::ParamValues;
 pub use scalar::{ScalarType, ScalarValue};
 pub use schema_reference::SchemaReference;
diff --git a/datafusion/common/src/null_equality.rs b/datafusion/common/src/null_equality.rs
new file mode 100644
index 0000000000..847fb09757
--- /dev/null
+++ b/datafusion/common/src/null_equality.rs
@@ -0,0 +1,46 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// Represents the behavior for null values when evaluating equality. Currently, its primary use
+/// case is to define the behavior of joins for null values.
+///
+/// # Examples
+///
+/// The following table shows the expected equality behavior for `NullEquality`.
+///
+/// | A    | B    | NullEqualsNothing | NullEqualsNull |
+/// |------|------|-------------------|----------------|
+/// | NULL | NULL | false             | true           |
+/// | NULL | 'b'  | false             | false          |
+/// | 'a'  | NULL | false             | false          |
+/// | 'a'  | 'b'  | false             | false          |
+///
+/// # Order
+///
+/// The order on this type represents the "restrictiveness" of the behavior. The more restrictive
+/// a behavior is, the fewer elements are considered to be equal to null.
+/// [NullEquality::NullEqualsNothing] represents the most restrictive behavior.
+///
+/// This mirrors the old order with `null_equals_null` booleans, as `false` indicated that
+/// `null != null`.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)]
+pub enum NullEquality {
+    /// Null is *not* equal to anything (`null != null`)
+    NullEqualsNothing,
+    /// Null is equal to null (`null == null`)
+    NullEqualsNull,
+}
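
As a quick illustration of the derived ordering documented above (a standalone sketch, assuming only a dependency on `datafusion_common`): the variants are declared most-restrictive first, so `NullEqualsNothing` compares as the smaller value, mirroring `false < true` for the old boolean flag.

    use datafusion_common::NullEquality;

    fn main() {
        // Most restrictive behavior first: a null key matches nothing.
        assert!(NullEquality::NullEqualsNothing < NullEquality::NullEqualsNull);
    }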
diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs
index a8086b4645..14188f6bf0 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -901,12 +901,10 @@ impl DefaultPhysicalPlanner {
                 on: keys,
                 filter,
                 join_type,
-                null_equals_null,
+                null_equality,
                 schema: join_schema,
                 ..
             }) => {
-                let null_equals_null = *null_equals_null;
-
                 let [physical_left, physical_right] = children.two()?;
 
                 // If join has expression equijoin keys, add physical projection.
@@ -1127,7 +1125,7 @@ impl DefaultPhysicalPlanner {
                         join_filter,
                         *join_type,
                         vec![SortOptions::default(); join_on_len],
-                        null_equals_null,
+                        *null_equality,
                     )?)
                 } else if session_state.config().target_partitions() > 1
                     && session_state.config().repartition_joins()
@@ -1141,7 +1139,7 @@ impl DefaultPhysicalPlanner {
                         join_type,
                         None,
                         PartitionMode::Auto,
-                        null_equals_null,
+                        *null_equality,
                     )?)
                 } else {
                     Arc::new(HashJoinExec::try_new(
@@ -1152,7 +1150,7 @@ impl DefaultPhysicalPlanner {
                         join_type,
                         None,
                         PartitionMode::CollectLeft,
-                        null_equals_null,
+                        *null_equality,
                     )?)
                 };
 
diff --git a/datafusion/core/tests/execution/infinite_cancel.rs b/datafusion/core/tests/execution/infinite_cancel.rs
index 00c1f6b448..ea35dd367b 100644
--- a/datafusion/core/tests/execution/infinite_cancel.rs
+++ b/datafusion/core/tests/execution/infinite_cancel.rs
@@ -33,7 +33,7 @@ use datafusion::physical_plan::execution_plan::Boundedness;
 use datafusion::physical_plan::ExecutionPlan;
 use datafusion::prelude::SessionContext;
 use datafusion_common::config::ConfigOptions;
-use datafusion_common::{JoinType, ScalarValue};
+use datafusion_common::{JoinType, NullEquality, ScalarValue};
 use datafusion_expr_common::operator::Operator::Gt;
 use datafusion_physical_expr::expressions::{col, BinaryExpr, Column, Literal};
 use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
@@ -467,7 +467,7 @@ async fn test_infinite_join_cancel(
         &JoinType::Inner,
         None,
         PartitionMode::CollectLeft,
-        true,
+        NullEquality::NullEqualsNull,
     )?);
 
     // 3) Wrap yields under each infinite leaf
@@ -550,7 +550,7 @@ async fn test_infinite_join_agg_cancel(
         &JoinType::Inner,
         None,
         PartitionMode::CollectLeft,
-        true,
+        NullEquality::NullEqualsNull,
     )?);
 
     // 3) Project only one column (“value” from the left side) because we just want to sum that
@@ -714,7 +714,7 @@ async fn test_infinite_hash_join_without_repartition_and_no_agg(
         /* output64 */ None,
         // Using CollectLeft is fine—just avoid RepartitionExec’s partitioned channels.
         PartitionMode::CollectLeft,
-        /* build_left */ true,
+        /* build_left */ NullEquality::NullEqualsNull,
     )?);
 
     // 3) Do not apply InsertYieldExec—since there is no aggregation, InsertYieldExec would
@@ -796,7 +796,7 @@ async fn test_infinite_sort_merge_join_without_repartition_and_no_agg(
         /* filter */ None,
         JoinType::Inner,
         vec![SortOptions::new(true, false)], // ascending, nulls last
-        /* null_equal */ true,
+        /* null_equality */ NullEquality::NullEqualsNull,
     )?);
 
     // 3) Do not apply InsertYieldExec (no aggregation, no repartition → no built-in yields).
diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs
index 1a8064ac1e..7250a263d8 100644
--- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs
@@ -37,7 +37,7 @@ use datafusion::physical_plan::joins::{
     HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec,
 };
 use datafusion::prelude::{SessionConfig, SessionContext};
-use datafusion_common::ScalarValue;
+use datafusion_common::{NullEquality, ScalarValue};
 use datafusion_physical_expr::expressions::Literal;
 use datafusion_physical_expr::PhysicalExprRef;
 
@@ -504,7 +504,7 @@ impl JoinFuzzTestCase {
                 self.join_filter(),
                 self.join_type,
                 vec![SortOptions::default(); self.on_columns().len()],
-                false,
+                NullEquality::NullEqualsNothing,
             )
             .unwrap(),
         )
@@ -521,7 +521,7 @@ impl JoinFuzzTestCase {
                 &self.join_type,
                 None,
                 PartitionMode::Partitioned,
-                false,
+                NullEquality::NullEqualsNothing,
             )
             .unwrap(),
         )
diff --git a/datafusion/core/tests/physical_optimizer/join_selection.rs b/datafusion/core/tests/physical_optimizer/join_selection.rs
index d8c0c142f7..3477ac7712 100644
--- a/datafusion/core/tests/physical_optimizer/join_selection.rs
+++ b/datafusion/core/tests/physical_optimizer/join_selection.rs
@@ -25,8 +25,8 @@ use std::{
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use arrow::record_batch::RecordBatch;
 use datafusion_common::config::ConfigOptions;
-use datafusion_common::JoinSide;
-use datafusion_common::{stats::Precision, ColumnStatistics, JoinType, ScalarValue};
+use datafusion_common::{JoinSide, NullEquality};
 use datafusion_common::{Result, Statistics};
 use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
 use datafusion_expr::Operator;
@@ -222,7 +222,7 @@ async fn test_join_with_swap() {
             &JoinType::Left,
             None,
             PartitionMode::CollectLeft,
-            false,
+            NullEquality::NullEqualsNothing,
         )
         .unwrap(),
     );
@@ -284,7 +284,7 @@ async fn test_left_join_no_swap() {
             &JoinType::Left,
             None,
             PartitionMode::CollectLeft,
-            false,
+            NullEquality::NullEqualsNothing,
         )
         .unwrap(),
     );
@@ -333,7 +333,7 @@ async fn test_join_with_swap_semi() {
             &join_type,
             None,
             PartitionMode::Partitioned,
-            false,
+            NullEquality::NullEqualsNothing,
         )
         .unwrap();
 
@@ -408,7 +408,7 @@ async fn test_nested_join_swap() {
         &JoinType::Inner,
         None,
         PartitionMode::CollectLeft,
-        false,
+        NullEquality::NullEqualsNothing,
     )
     .unwrap();
     let child_schema = child_join.schema();
@@ -425,7 +425,7 @@ async fn test_nested_join_swap() {
         &JoinType::Left,
         None,
         PartitionMode::CollectLeft,
-        false,
+        NullEquality::NullEqualsNothing,
     )
     .unwrap();
 
@@ -464,7 +464,7 @@ async fn test_join_no_swap() {
             &JoinType::Inner,
             None,
             PartitionMode::CollectLeft,
-            false,
+            NullEquality::NullEqualsNothing,
         )
         .unwrap(),
     );
@@ -690,7 +690,7 @@ async fn test_hash_join_swap_on_joins_with_projections(
         &join_type,
         Some(projection),
         PartitionMode::Partitioned,
-        false,
+        NullEquality::NullEqualsNothing,
     )?);
 
     let swapped = join
@@ -851,7 +851,7 @@ fn check_join_partition_mode(
             &JoinType::Inner,
             None,
             PartitionMode::Auto,
-            false,
+            NullEquality::NullEqualsNothing,
         )
         .unwrap(),
     );
@@ -1498,7 +1498,7 @@ async fn test_join_with_maybe_swap_unbounded_case(t: TestCase) -> Result<()> {
         &t.initial_join_type,
         None,
         t.initial_mode,
-        false,
+        NullEquality::NullEqualsNothing,
     )?) as _;
 
     let optimized_join_plan = hash_join_swap_subrule(join, &ConfigOptions::new())?;
diff --git a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs
index f2958deb57..1f8aad0f23 100644
--- a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs
+++ b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs
@@ -25,7 +25,7 @@ use datafusion::datasource::memory::MemorySourceConfig;
 use datafusion::datasource::physical_plan::CsvSource;
 use datafusion::datasource::source::DataSourceExec;
 use datafusion_common::config::ConfigOptions;
-use datafusion_common::{JoinSide, JoinType, Result, ScalarValue};
+use datafusion_common::{JoinSide, JoinType, NullEquality, Result, ScalarValue};
 use datafusion_datasource::file_scan_config::FileScanConfigBuilder;
 use datafusion_execution::object_store::ObjectStoreUrl;
 use datafusion_execution::{SendableRecordBatchStream, TaskContext};
@@ -883,7 +883,7 @@ fn test_join_after_projection() -> Result<()> {
             ])),
         )),
         &JoinType::Inner,
-        true,
+        NullEquality::NullEqualsNull,
         None,
         None,
         StreamJoinPartitionMode::SinglePartition,
@@ -997,7 +997,7 @@ fn test_join_after_required_projection() -> Result<()> {
             ])),
         )),
         &JoinType::Inner,
-        true,
+        NullEquality::NullEqualsNull,
         None,
         None,
         StreamJoinPartitionMode::SinglePartition,
@@ -1158,7 +1158,7 @@ fn test_hash_join_after_projection() -> Result<()> {
         &JoinType::Inner,
         None,
         PartitionMode::Auto,
-        true,
+        NullEquality::NullEqualsNull,
     )?);
     let projection = Arc::new(ProjectionExec::try_new(
         vec![
diff --git a/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs
index f7c134cae3..c9baa9a932 100644
--- a/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs
+++ b/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs
@@ -30,7 +30,7 @@ use arrow::compute::SortOptions;
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use arrow::record_batch::RecordBatch;
 use datafusion_common::tree_node::{TransformedResult, TreeNode};
-use datafusion_common::{assert_contains, Result};
+use datafusion_common::{assert_contains, NullEquality, Result};
 use datafusion_common::config::ConfigOptions;
 use datafusion_datasource::source::DataSourceExec;
 use datafusion_execution::TaskContext;
@@ -1171,7 +1171,7 @@ fn hash_join_exec(
             &JoinType::Inner,
             None,
             PartitionMode::Partitioned,
-            false,
+            NullEquality::NullEqualsNothing,
         )
         .unwrap(),
     )
diff --git a/datafusion/core/tests/physical_optimizer/test_utils.rs b/datafusion/core/tests/physical_optimizer/test_utils.rs
index ebb623dc61..c91a70989b 100644
--- a/datafusion/core/tests/physical_optimizer/test_utils.rs
+++ b/datafusion/core/tests/physical_optimizer/test_utils.rs
@@ -33,7 +33,7 @@ use datafusion_common::config::ConfigOptions;
 use datafusion_common::stats::Precision;
 use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
 use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
-use datafusion_common::{ColumnStatistics, JoinType, Result, Statistics};
+use datafusion_common::{ColumnStatistics, JoinType, NullEquality, Result, Statistics};
 use datafusion_datasource::file_scan_config::FileScanConfigBuilder;
 use datafusion_execution::object_store::ObjectStoreUrl;
 use datafusion_execution::{SendableRecordBatchStream, TaskContext};
@@ -190,7 +190,7 @@ pub fn sort_merge_join_exec(
             None,
             *join_type,
             vec![SortOptions::default(); join_on.len()],
-            false,
+            NullEquality::NullEqualsNothing,
         )
         .unwrap(),
     )
@@ -236,7 +236,7 @@ pub fn hash_join_exec(
         join_type,
         None,
         PartitionMode::Partitioned,
-        true,
+        NullEquality::NullEqualsNull,
     )?))
 }
 
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index fbd1fdadc4..93dd6c2b89 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -56,7 +56,8 @@ use datafusion_common::file_options::file_type::FileType;
 use datafusion_common::{
     exec_err, get_target_functional_dependencies, internal_err, not_impl_err,
     plan_datafusion_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef,
-    DataFusionError, Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions,
+    DataFusionError, NullEquality, Result, ScalarValue, TableReference, ToDFSchema,
+    UnnestOptions,
 };
 use datafusion_expr_common::type_coercion::binary::type_union_resolution;
 
@@ -903,7 +904,13 @@ impl LogicalPlanBuilder {
         join_keys: (Vec<impl Into<Column>>, Vec<impl Into<Column>>),
         filter: Option<Expr>,
     ) -> Result<Self> {
-        self.join_detailed(right, join_type, join_keys, filter, false)
+        self.join_detailed(
+            right,
+            join_type,
+            join_keys,
+            filter,
+            NullEquality::NullEqualsNothing,
+        )
     }
 
     /// Apply a join using the specified expressions.
@@ -959,7 +966,7 @@ impl LogicalPlanBuilder {
             join_type,
             (Vec::<Column>::new(), Vec::<Column>::new()),
             filter,
-            false,
+            NullEquality::NullEqualsNothing,
         )
     }
 
@@ -987,16 +994,14 @@ impl LogicalPlanBuilder {
     /// The behavior is the same as [`join`](Self::join) except that it allows
     /// specifying the null equality behavior.
     ///
-    /// If `null_equals_null=true`, rows where both join keys are `null` will be
-    /// emitted. Otherwise rows where either or both join keys are `null` will be
-    /// omitted.
+    /// The `null_equality` dictates how `null` values are joined.
     pub fn join_detailed(
         self,
         right: LogicalPlan,
         join_type: JoinType,
         join_keys: (Vec<impl Into<Column>>, Vec<impl Into<Column>>),
         filter: Option<Expr>,
-        null_equals_null: bool,
+        null_equality: NullEquality,
     ) -> Result<Self> {
         if join_keys.0.len() != join_keys.1.len() {
             return plan_err!("left_keys and right_keys were not the same length");
@@ -1113,7 +1118,7 @@ impl LogicalPlanBuilder {
             join_type,
             join_constraint: JoinConstraint::On,
             schema: DFSchemaRef::new(join_schema),
-            null_equals_null,
+            null_equality,
         })))
     }
 
@@ -1186,7 +1191,7 @@ impl LogicalPlanBuilder {
                 filters,
                 join_type,
                 JoinConstraint::Using,
-                false,
+                NullEquality::NullEqualsNothing,
             )?;
 
             Ok(Self::new(LogicalPlan::Join(join)))
@@ -1202,7 +1207,7 @@ impl LogicalPlanBuilder {
             None,
             JoinType::Inner,
             JoinConstraint::On,
-            false,
+            NullEquality::NullEqualsNothing,
         )?;
 
         Ok(Self::new(LogicalPlan::Join(join)))
@@ -1340,12 +1345,24 @@ impl LogicalPlanBuilder {
             .unzip();
         if is_all {
             LogicalPlanBuilder::from(left_plan)
-                .join_detailed(right_plan, join_type, join_keys, None, true)?
+                .join_detailed(
+                    right_plan,
+                    join_type,
+                    join_keys,
+                    None,
+                    NullEquality::NullEqualsNull,
+                )?
                 .build()
         } else {
             LogicalPlanBuilder::from(left_plan)
                 .distinct()?
-                .join_detailed(right_plan, join_type, join_keys, None, true)?
+                .join_detailed(
+                    right_plan,
+                    join_type,
+                    join_keys,
+                    None,
+                    NullEquality::NullEqualsNull,
+                )?
                 .build()
         }
     }
@@ -1423,7 +1440,7 @@ impl LogicalPlanBuilder {
             filter,
             join_type,
             JoinConstraint::On,
-            false,
+            NullEquality::NullEqualsNothing,
         )?;
 
         Ok(Self::new(LogicalPlan::Join(join)))
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 4ac2d182aa..876c14f100 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -56,8 +56,8 @@ use datafusion_common::tree_node::{
 use datafusion_common::{
     aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints,
     DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependence,
-    FunctionalDependencies, ParamValues, Result, ScalarValue, Spans, TableReference,
-    UnnestOptions,
+    FunctionalDependencies, NullEquality, ParamValues, Result, ScalarValue, Spans,
+    TableReference, UnnestOptions,
 };
 use indexmap::IndexSet;
 
@@ -657,7 +657,7 @@ impl LogicalPlan {
                 join_constraint,
                 on,
                 schema: _,
-                null_equals_null,
+                null_equality,
             }) => {
                 let schema =
                     build_join_schema(left.schema(), right.schema(), &join_type)?;
@@ -678,7 +678,7 @@ impl LogicalPlan {
                     on: new_on,
                     filter,
                     schema: DFSchemaRef::new(schema),
-                    null_equals_null,
+                    null_equality,
                 }))
             }
             LogicalPlan::Subquery(_) => Ok(self),
@@ -896,7 +896,7 @@ impl LogicalPlan {
                 join_type,
                 join_constraint,
                 on,
-                null_equals_null,
+                null_equality,
                 ..
             }) => {
                 let (left, right) = self.only_two_inputs(inputs)?;
@@ -935,7 +935,7 @@ impl LogicalPlan {
                     on: new_on,
                     filter: filter_expr,
                     schema: DFSchemaRef::new(schema),
-                    null_equals_null: *null_equals_null,
+                    null_equality: *null_equality,
                 }))
             }
             LogicalPlan::Subquery(Subquery {
@@ -3708,8 +3708,8 @@ pub struct Join {
     pub join_constraint: JoinConstraint,
     /// The output schema, containing fields from the left and right inputs
     pub schema: DFSchemaRef,
-    /// If null_equals_null is true, null == null else null != null
-    pub null_equals_null: bool,
+    /// Defines the null equality for the join.
+    pub null_equality: NullEquality,
 }
 
 impl Join {
@@ -3726,7 +3726,7 @@ impl Join {
     /// * `filter` - Optional filter expression (for non-equijoin conditions)
     /// * `join_type` - Type of join (Inner, Left, Right, etc.)
     /// * `join_constraint` - Join constraint (On, Using)
-    /// * `null_equals_null` - Whether NULL = NULL in join comparisons
+    /// * `null_equality` - How to handle nulls in join comparisons
     ///
     /// # Returns
     ///
@@ -3738,7 +3738,7 @@ impl Join {
         filter: Option<Expr>,
         join_type: JoinType,
         join_constraint: JoinConstraint,
-        null_equals_null: bool,
+        null_equality: NullEquality,
     ) -> Result<Self> {
         let join_schema = build_join_schema(left.schema(), right.schema(), &join_type)?;
 
@@ -3750,7 +3750,7 @@ impl Join {
             join_type,
             join_constraint,
             schema: Arc::new(join_schema),
-            null_equals_null,
+            null_equality,
         })
     }
 
@@ -3783,7 +3783,7 @@ impl Join {
             join_type: original_join.join_type,
             join_constraint: original_join.join_constraint,
             schema: Arc::new(join_schema),
-            null_equals_null: original_join.null_equals_null,
+            null_equality: original_join.null_equality,
         })
     }
 }
@@ -3805,8 +3805,8 @@ impl PartialOrd for Join {
             pub join_type: &'a JoinType,
             /// Join constraint
             pub join_constraint: &'a JoinConstraint,
-            /// If null_equals_null is true, null == null else null != null
-            pub null_equals_null: &'a bool,
+            /// The null handling behavior for equalities
+            pub null_equality: &'a NullEquality,
         }
         let comparable_self = ComparableJoin {
             left: &self.left,
@@ -3815,7 +3815,7 @@ impl PartialOrd for Join {
             filter: &self.filter,
             join_type: &self.join_type,
             join_constraint: &self.join_constraint,
-            null_equals_null: &self.null_equals_null,
+            null_equality: &self.null_equality,
         };
         let comparable_other = ComparableJoin {
             left: &other.left,
@@ -3824,7 +3824,7 @@ impl PartialOrd for Join {
             filter: &other.filter,
             join_type: &other.join_type,
             join_constraint: &other.join_constraint,
-            null_equals_null: &other.null_equals_null,
+            null_equality: &other.null_equality,
         };
         comparable_self.partial_cmp(&comparable_other)
     }
@@ -4895,7 +4895,7 @@ mod tests {
                 join_type: JoinType::Inner,
                 join_constraint: JoinConstraint::On,
                 schema: Arc::new(left_schema.join(&right_schema)?),
-                null_equals_null: false,
+                null_equality: NullEquality::NullEqualsNothing,
             }))
         }
 
@@ -5006,7 +5006,7 @@ mod tests {
                 Some(col("t1.b").gt(col("t2.b"))),
                 join_type,
                 JoinConstraint::On,
-                false,
+                NullEquality::NullEqualsNothing,
             )?;
 
             match join_type {
@@ -5116,7 +5116,7 @@ mod tests {
             assert_eq!(join.filter, Some(col("t1.b").gt(col("t2.b"))));
             assert_eq!(join.join_type, join_type);
             assert_eq!(join.join_constraint, JoinConstraint::On);
-            assert!(!join.null_equals_null);
+            assert_eq!(join.null_equality, NullEquality::NullEqualsNothing);
         }
 
         Ok(())
@@ -5151,7 +5151,7 @@ mod tests {
                 None,
                 JoinType::Inner,
                 JoinConstraint::Using,
-                false,
+                NullEquality::NullEqualsNothing,
             )?;
 
             let fields = join.schema.fields();
@@ -5202,7 +5202,7 @@ mod tests {
                 Some(col("t1.value").lt(col("t2.value"))), // Non-equi filter condition
                 JoinType::Inner,
                 JoinConstraint::On,
-                false,
+                NullEquality::NullEqualsNothing,
             )?;
 
             let fields = join.schema.fields();
@@ -5251,10 +5251,10 @@ mod tests {
                 None,
                 JoinType::Inner,
                 JoinConstraint::On,
-                true,
+                NullEquality::NullEqualsNull,
             )?;
 
-            assert!(join.null_equals_null);
+            assert_eq!(join.null_equality, NullEquality::NullEqualsNull);
         }
 
         Ok(())
@@ -5293,7 +5293,7 @@ mod tests {
                 Some(col("t1.value").gt(lit(5.0))),
                 join_type,
                 JoinConstraint::On,
-                false,
+                NullEquality::NullEqualsNothing,
             )?;
 
             let fields = join.schema.fields();
@@ -5332,7 +5332,7 @@ mod tests {
             None,
             JoinType::Inner,
             JoinConstraint::Using,
-            false,
+            NullEquality::NullEqualsNothing,
         )?;
 
         assert_eq!(
diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs
index 7f5b7e07ed..527248ad39 100644
--- a/datafusion/expr/src/logical_plan/tree_node.rs
+++ b/datafusion/expr/src/logical_plan/tree_node.rs
@@ -132,7 +132,7 @@ impl TreeNode for LogicalPlan {
                 join_type,
                 join_constraint,
                 schema,
-                null_equals_null,
+                null_equality,
             }) => (left, right).map_elements(f)?.update_data(|(left, right)| {
                 LogicalPlan::Join(Join {
                     left,
@@ -142,7 +142,7 @@ impl TreeNode for LogicalPlan {
                     join_type,
                     join_constraint,
                     schema,
-                    null_equals_null,
+                    null_equality,
                 })
             }),
             LogicalPlan::Limit(Limit { skip, fetch, input }) => input
@@ -561,7 +561,7 @@ impl LogicalPlan {
                 join_type,
                 join_constraint,
                 schema,
-                null_equals_null,
+                null_equality,
             }) => (on, filter).map_elements(f)?.update_data(|(on, filter)| {
                 LogicalPlan::Join(Join {
                     left,
@@ -571,7 +571,7 @@ impl LogicalPlan {
                     join_type,
                     join_constraint,
                     schema,
-                    null_equals_null,
+                    null_equality,
                 })
             }),
             LogicalPlan::Sort(Sort { expr, input, fetch }) => expr
diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs
index deefaef2c0..ae1d7df46d 100644
--- a/datafusion/optimizer/src/eliminate_cross_join.rs
+++ b/datafusion/optimizer/src/eliminate_cross_join.rs
@@ -21,7 +21,7 @@ use std::sync::Arc;
 
 use crate::join_key_set::JoinKeySet;
 use datafusion_common::tree_node::{Transformed, TreeNode};
-use datafusion_common::Result;
+use datafusion_common::{NullEquality, Result};
 use datafusion_expr::expr::{BinaryExpr, Expr};
 use datafusion_expr::logical_plan::{
     Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection,
@@ -89,7 +89,7 @@ impl OptimizerRule for EliminateCrossJoin {
         let mut possible_join_keys = JoinKeySet::new();
         let mut all_inputs: Vec<LogicalPlan> = vec![];
         let mut all_filters: Vec<Expr> = vec![];
-        let mut null_equals_null = false;
+        let mut null_equality = NullEquality::NullEqualsNothing;
 
         let parent_predicate = if let LogicalPlan::Filter(filter) = plan {
             // if input isn't a join that can potentially be rewritten
@@ -115,9 +115,9 @@ impl OptimizerRule for EliminateCrossJoin {
                 input, predicate, ..
             } = filter;
 
-            // Extract null_equals_null setting from the input join
+            // Extract null_equality setting from the input join
             if let LogicalPlan::Join(join) = input.as_ref() {
-                null_equals_null = join.null_equals_null;
+                null_equality = join.null_equality;
             }
 
             flatten_join_inputs(
@@ -133,7 +133,7 @@ impl OptimizerRule for EliminateCrossJoin {
             match plan {
                 LogicalPlan::Join(Join {
                     join_type: JoinType::Inner,
-                    null_equals_null: original_null_equals_null,
+                    null_equality: original_null_equality,
                     ..
                 }) => {
                     if !can_flatten_join_inputs(&plan) {
@@ -145,7 +145,7 @@ impl OptimizerRule for EliminateCrossJoin {
                         &mut all_inputs,
                         &mut all_filters,
                     )?;
-                    null_equals_null = original_null_equals_null;
+                    null_equality = original_null_equality;
                     None
                 }
                 _ => {
@@ -164,7 +164,7 @@ impl OptimizerRule for EliminateCrossJoin {
                 &mut all_inputs,
                 &possible_join_keys,
                 &mut all_join_keys,
-                null_equals_null,
+                null_equality,
             )?;
         }
 
@@ -302,7 +302,7 @@ fn find_inner_join(
     rights: &mut Vec<LogicalPlan>,
     possible_join_keys: &JoinKeySet,
     all_join_keys: &mut JoinKeySet,
-    null_equals_null: bool,
+    null_equality: NullEquality,
 ) -> Result<LogicalPlan> {
     for (i, right_input) in rights.iter().enumerate() {
         let mut join_keys = vec![];
@@ -341,7 +341,7 @@ fn find_inner_join(
                 on: join_keys,
                 filter: None,
                 schema: join_schema,
-                null_equals_null,
+                null_equality,
             }));
         }
     }
@@ -363,7 +363,7 @@ fn find_inner_join(
         filter: None,
         join_type: JoinType::Inner,
         join_constraint: JoinConstraint::On,
-        null_equals_null,
+        null_equality,
     }))
 }
 
@@ -1348,11 +1348,11 @@ mod tests {
     }
 
     #[test]
-    fn preserve_null_equals_null_setting() -> Result<()> {
+    fn preserve_null_equality_setting() -> Result<()> {
         let t1 = test_table_scan_with_name("t1")?;
         let t2 = test_table_scan_with_name("t2")?;
 
-        // Create an inner join with null_equals_null: true
+        // Create an inner join with NullEquality::NullEqualsNull
         let join_schema = Arc::new(build_join_schema(
             t1.schema(),
             t2.schema(),
@@ -1367,7 +1367,7 @@ mod tests {
             on: vec![],
             filter: None,
             schema: join_schema,
-            null_equals_null: true, // Set to true to test preservation
+            null_equality: NullEquality::NullEqualsNull, // Test preservation
         });
 
         // Apply filter that can create join conditions
@@ -1382,31 +1382,31 @@ mod tests {
         let rule = EliminateCrossJoin::new();
         let optimized_plan = rule.rewrite(plan, &OptimizerContext::new())?.data;
 
-        // Verify that null_equals_null is preserved in the optimized plan
-        fn check_null_equals_null_preserved(plan: &LogicalPlan) -> bool {
+        // Verify that null_equality is preserved in the optimized plan
+        fn check_null_equality_preserved(plan: &LogicalPlan) -> bool {
             match plan {
                 LogicalPlan::Join(join) => {
-                    // All joins in the optimized plan should preserve null_equals_null: true
-                    if !join.null_equals_null {
+                    // All joins in the optimized plan should preserve null equality
+                    if join.null_equality == NullEquality::NullEqualsNothing {
                         return false;
                     }
                     // Recursively check child plans
                     plan.inputs()
                         .iter()
-                        .all(|input| check_null_equals_null_preserved(input))
+                        .all(|input| check_null_equality_preserved(input))
                 }
                 _ => {
                     // Recursively check child plans for non-join nodes
                     plan.inputs()
                         .iter()
-                        .all(|input| check_null_equals_null_preserved(input))
+                        .all(|input| check_null_equality_preserved(input))
                 }
             }
         }
 
         assert!(
-            check_null_equals_null_preserved(&optimized_plan),
-            "null_equals_null setting should be preserved after optimization"
+            check_null_equality_preserved(&optimized_plan),
+            "null_equality setting should be preserved after optimization"
         );
 
         Ok(())
diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs
index 621086e4a2..45877642f2 100644
--- a/datafusion/optimizer/src/eliminate_outer_join.rs
+++ b/datafusion/optimizer/src/eliminate_outer_join.rs
@@ -118,7 +118,7 @@ impl OptimizerRule for EliminateOuterJoin {
                         on: join.on.clone(),
                         filter: join.filter.clone(),
                         schema: Arc::clone(&join.schema),
-                        null_equals_null: join.null_equals_null,
+                        null_equality: join.null_equality,
                     }));
                     Filter::try_new(filter.predicate, new_join)
                         .map(|f| Transformed::yes(LogicalPlan::Filter(f)))
diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs
index a07b50ade5..55cf33ef43 100644
--- a/datafusion/optimizer/src/extract_equijoin_predicate.rs
+++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs
@@ -75,7 +75,7 @@ impl OptimizerRule for ExtractEquijoinPredicate {
                 join_type,
                 join_constraint,
                 schema,
-                null_equals_null,
+                null_equality,
             }) => {
                 let left_schema = left.schema();
                 let right_schema = right.schema();
@@ -92,7 +92,7 @@ impl OptimizerRule for ExtractEquijoinPredicate {
                         join_type,
                         join_constraint,
                         schema,
-                        null_equals_null,
+                        null_equality,
                     })))
                 } else {
                     Ok(Transformed::no(LogicalPlan::Join(Join {
@@ -103,7 +103,7 @@ impl OptimizerRule for ExtractEquijoinPredicate {
                         join_type,
                         join_constraint,
                         schema,
-                        null_equals_null,
+                        null_equality,
                     })))
                 }
             }
diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs b/datafusion/optimizer/src/filter_null_join_keys.rs
index 14a424b326..8ad7fa53c0 100644
--- a/datafusion/optimizer/src/filter_null_join_keys.rs
+++ b/datafusion/optimizer/src/filter_null_join_keys.rs
@@ -21,7 +21,7 @@ use crate::optimizer::ApplyOrder;
 use crate::push_down_filter::on_lr_is_preserved;
 use crate::{OptimizerConfig, OptimizerRule};
 use datafusion_common::tree_node::Transformed;
-use datafusion_common::Result;
+use datafusion_common::{NullEquality, Result};
 use datafusion_expr::utils::conjunction;
 use datafusion_expr::{logical_plan::Filter, Expr, ExprSchemable, LogicalPlan};
 use std::sync::Arc;
@@ -51,7 +51,8 @@ impl OptimizerRule for FilterNullJoinKeys {
         }
         match plan {
             LogicalPlan::Join(mut join)
-                if !join.on.is_empty() && !join.null_equals_null =>
+                if !join.on.is_empty()
+                    && join.null_equality == NullEquality::NullEqualsNothing =>
             {
                 let (left_preserved, right_preserved) =
                     on_lr_is_preserved(join.join_type);
diff --git a/datafusion/physical-optimizer/src/enforce_distribution.rs b/datafusion/physical-optimizer/src/enforce_distribution.rs
index 566cf2f3a2..39eb557ea6 100644
--- a/datafusion/physical-optimizer/src/enforce_distribution.rs
+++ b/datafusion/physical-optimizer/src/enforce_distribution.rs
@@ -295,7 +295,7 @@ pub fn adjust_input_keys_ordering(
         join_type,
         projection,
         mode,
-        null_equals_null,
+        null_equality,
         ..
     }) = plan.as_any().downcast_ref::<HashJoinExec>()
     {
@@ -314,7 +314,7 @@ pub fn adjust_input_keys_ordering(
                         // TODO: although projection is not used in the join here, because projection pushdown is after enforce_distribution. Maybe we need to handle it later. Same as filter.
                         projection.clone(),
                         PartitionMode::Partitioned,
-                        *null_equals_null,
+                        *null_equality,
                     )
                     .map(|e| Arc::new(e) as _)
                 };
@@ -364,7 +364,7 @@ pub fn adjust_input_keys_ordering(
         filter,
         join_type,
         sort_options,
-        null_equals_null,
+        null_equality,
         ..
     }) = plan.as_any().downcast_ref::<SortMergeJoinExec>()
     {
@@ -379,7 +379,7 @@ pub fn adjust_input_keys_ordering(
                 filter.clone(),
                 *join_type,
                 new_conditions.1,
-                *null_equals_null,
+                *null_equality,
             )
             .map(|e| Arc::new(e) as _)
         };
@@ -616,7 +616,7 @@ pub fn reorder_join_keys_to_inputs(
         join_type,
         projection,
         mode,
-        null_equals_null,
+        null_equality,
         ..
     }) = plan_any.downcast_ref::<HashJoinExec>()
     {
@@ -642,7 +642,7 @@ pub fn reorder_join_keys_to_inputs(
                     join_type,
                     projection.clone(),
                     PartitionMode::Partitioned,
-                    *null_equals_null,
+                    *null_equality,
                 )?));
             }
         }
@@ -653,7 +653,7 @@ pub fn reorder_join_keys_to_inputs(
         filter,
         join_type,
         sort_options,
-        null_equals_null,
+        null_equality,
         ..
     }) = plan_any.downcast_ref::<SortMergeJoinExec>()
     {
@@ -681,7 +681,7 @@ pub fn reorder_join_keys_to_inputs(
                     filter.clone(),
                     *join_type,
                     new_sort_options,
-                    *null_equals_null,
+                    *null_equality,
                 )
                 .map(|smj| Arc::new(smj) as _);
             }
diff --git a/datafusion/physical-optimizer/src/join_selection.rs b/datafusion/physical-optimizer/src/join_selection.rs
index 27eed70241..dc22033214 100644
--- a/datafusion/physical-optimizer/src/join_selection.rs
+++ b/datafusion/physical-optimizer/src/join_selection.rs
@@ -245,7 +245,7 @@ pub(crate) fn try_collect_left(
                     hash_join.join_type(),
                     hash_join.projection.clone(),
                     PartitionMode::CollectLeft,
-                    hash_join.null_equals_null(),
+                    hash_join.null_equality(),
                 )?)))
             }
         }
@@ -257,7 +257,7 @@ pub(crate) fn try_collect_left(
             hash_join.join_type(),
             hash_join.projection.clone(),
             PartitionMode::CollectLeft,
-            hash_join.null_equals_null(),
+            hash_join.null_equality(),
         )?))),
         (false, true) => {
             if hash_join.join_type().supports_swap() {
@@ -292,7 +292,7 @@ pub(crate) fn partitioned_hash_join(
             hash_join.join_type(),
             hash_join.projection.clone(),
             PartitionMode::Partitioned,
-            hash_join.null_equals_null(),
+            hash_join.null_equality(),
         )?))
     }
 }
@@ -474,7 +474,7 @@ fn hash_join_convert_symmetric_subrule(
                 hash_join.on().to_vec(),
                 hash_join.filter().cloned(),
                 hash_join.join_type(),
-                hash_join.null_equals_null(),
+                hash_join.null_equality(),
                 left_order,
                 right_order,
                 mode,
diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs
index 5034a199e2..8c4241be72 100644
--- a/datafusion/physical-plan/src/joins/hash_join.rs
+++ b/datafusion/physical-plan/src/joins/hash_join.rs
@@ -70,7 +70,7 @@ use arrow::util::bit_util;
 use datafusion_common::utils::memory::estimate_memory_size;
 use datafusion_common::{
     internal_datafusion_err, internal_err, plan_err, project_schema, DataFusionError,
-    JoinSide, JoinType, Result,
+    JoinSide, JoinType, NullEquality, Result,
 };
 use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
 use datafusion_execution::TaskContext;
@@ -353,11 +353,8 @@ pub struct HashJoinExec {
     pub projection: Option<Vec<usize>>,
     /// Information of index and left / right placement of columns
     column_indices: Vec<ColumnIndex>,
-    /// Null matching behavior: If `null_equals_null` is true, rows that have
-    /// `null`s in both left and right equijoin columns will be matched.
-    /// Otherwise, rows that have `null`s in the join columns will not be
-    /// matched and thus will not appear in the output.
-    pub null_equals_null: bool,
+    /// The equality null-handling behavior of the join algorithm.
+    pub null_equality: NullEquality,
     /// Cache holding plan properties like equivalences, output partitioning etc.
     cache: PlanProperties,
 }
@@ -376,7 +373,7 @@ impl HashJoinExec {
         join_type: &JoinType,
         projection: Option<Vec<usize>>,
         partition_mode: PartitionMode,
-        null_equals_null: bool,
+        null_equality: NullEquality,
     ) -> Result<Self> {
         let left_schema = left.schema();
         let right_schema = right.schema();
@@ -419,7 +416,7 @@ impl HashJoinExec {
             metrics: ExecutionPlanMetricsSet::new(),
             projection,
             column_indices,
-            null_equals_null,
+            null_equality,
             cache,
         })
     }
@@ -460,9 +457,9 @@ impl HashJoinExec {
         &self.mode
     }
 
-    /// Get null_equals_null
-    pub fn null_equals_null(&self) -> bool {
-        self.null_equals_null
+    /// Get null_equality
+    pub fn null_equality(&self) -> NullEquality {
+        self.null_equality
     }
 
     /// Calculate order preservation flags for this hash join.
@@ -510,7 +507,7 @@ impl HashJoinExec {
             &self.join_type,
             projection,
             self.mode,
-            self.null_equals_null,
+            self.null_equality,
         )
     }
 
@@ -619,7 +616,7 @@ impl HashJoinExec {
                 self.join_type(),
             ),
             partition_mode,
-            self.null_equals_null(),
+            self.null_equality(),
         )?;
         // In case of anti / semi joins or if there is embedded projection in HashJoinExec, output column order is preserved, no need to add projection again
         if matches!(
@@ -767,7 +764,7 @@ impl ExecutionPlan for HashJoinExec {
             &self.join_type,
             self.projection.clone(),
             self.mode,
-            self.null_equals_null,
+            self.null_equality,
         )?))
     }
 
@@ -871,7 +868,7 @@ impl ExecutionPlan for HashJoinExec {
             column_indices: column_indices_after_projection,
             random_state: self.random_state.clone(),
             join_metrics,
-            null_equals_null: self.null_equals_null,
+            null_equality: self.null_equality,
             state: HashJoinStreamState::WaitBuildSide,
             build_side: BuildSide::Initial(BuildSideInitialState { left_fut }),
             batch_size,
@@ -940,7 +937,7 @@ impl ExecutionPlan for HashJoinExec {
                 // Returned early if projection is not None
                 None,
                 *self.partition_mode(),
-                self.null_equals_null,
+                self.null_equality,
             )?)))
         } else {
             try_embed_projection(projection, self)
@@ -1231,8 +1228,8 @@ struct HashJoinStream {
     join_metrics: BuildProbeJoinMetrics,
     /// Information of index and left / right placement of columns
     column_indices: Vec<ColumnIndex>,
-    /// If null_equals_null is true, null == null else null != null
-    null_equals_null: bool,
+    /// Defines the null equality for the join.
+    null_equality: NullEquality,
     /// State of the stream
     state: HashJoinStreamState,
     /// Build side
@@ -1304,7 +1301,7 @@ fn lookup_join_hashmap(
     build_hashmap: &JoinHashMap,
     build_side_values: &[ArrayRef],
     probe_side_values: &[ArrayRef],
-    null_equals_null: bool,
+    null_equality: NullEquality,
     hashes_buffer: &[u64],
     limit: usize,
     offset: JoinHashMapOffset,
@@ -1320,7 +1317,7 @@ fn lookup_join_hashmap(
         &probe_indices,
         build_side_values,
         probe_side_values,
-        null_equals_null,
+        null_equality,
     )?;
 
     Ok((build_indices, probe_indices, next_offset))
@@ -1330,22 +1327,21 @@ fn lookup_join_hashmap(
 fn eq_dyn_null(
     left: &dyn Array,
     right: &dyn Array,
-    null_equals_null: bool,
+    null_equality: NullEquality,
 ) -> Result<BooleanArray, ArrowError> {
     // Nested datatypes cannot use the underlying not_distinct/eq function and must use a special
     // implementation
     // <https://github.com/apache/datafusion/issues/10749>
     if left.data_type().is_nested() {
-        let op = if null_equals_null {
-            Operator::IsNotDistinctFrom
-        } else {
-            Operator::Eq
+        let op = match null_equality {
+            NullEquality::NullEqualsNothing => Operator::Eq,
+            NullEquality::NullEqualsNull => Operator::IsNotDistinctFrom,
         };
         return Ok(compare_op_for_nested(op, &left, &right)?);
     }
-    match (left.data_type(), right.data_type()) {
-        _ if null_equals_null => not_distinct(&left, &right),
-        _ => eq(&left, &right),
+    match null_equality {
+        NullEquality::NullEqualsNothing => eq(&left, &right),
+        NullEquality::NullEqualsNull => not_distinct(&left, &right),
     }
 }
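
The difference the rewritten `eq_dyn_null` expresses through the enum can be seen directly with the two arrow comparison kernels it dispatches to (a standalone sketch, assuming the `arrow` crate's `eq` and `not_distinct` kernels under `arrow::compute::kernels::cmp`):

    use arrow::array::{Array, BooleanArray, Int32Array};
    use arrow::compute::kernels::cmp::{eq, not_distinct};
    use arrow::error::ArrowError;

    fn main() -> Result<(), ArrowError> {
        let left = Int32Array::from(vec![Some(1), None]);
        let right = Int32Array::from(vec![Some(1), None]);

        // NullEquality::NullEqualsNothing -> `eq`: a NULL key never matches (the result is NULL).
        let strict: BooleanArray = eq(&left, &right)?;
        assert!(strict.value(0));
        assert!(strict.is_null(1));

        // NullEquality::NullEqualsNull -> `not_distinct`: NULL is not distinct from NULL.
        let lenient: BooleanArray = not_distinct(&left, &right)?;
        assert!(lenient.value(0));
        assert!(lenient.value(1));

        Ok(())
    }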
 
@@ -1354,7 +1350,7 @@ pub fn equal_rows_arr(
     indices_right: &UInt32Array,
     left_arrays: &[ArrayRef],
     right_arrays: &[ArrayRef],
-    null_equals_null: bool,
+    null_equality: NullEquality,
 ) -> Result<(UInt64Array, UInt32Array)> {
     let mut iter = left_arrays.iter().zip(right_arrays.iter());
 
@@ -1367,7 +1363,7 @@ pub fn equal_rows_arr(
     let arr_left = take(first_left.as_ref(), indices_left, None)?;
     let arr_right = take(first_right.as_ref(), indices_right, None)?;
 
-    let mut equal: BooleanArray = eq_dyn_null(&arr_left, &arr_right, null_equals_null)?;
+    let mut equal: BooleanArray = eq_dyn_null(&arr_left, &arr_right, null_equality)?;
 
     // Use map and try_fold to iterate over the remaining pairs of arrays.
     // In each iteration, take is used on the pair of arrays and their equality is determined.
@@ -1376,7 +1372,7 @@ pub fn equal_rows_arr(
         .map(|(left, right)| {
             let arr_left = take(left.as_ref(), indices_left, None)?;
             let arr_right = take(right.as_ref(), indices_right, None)?;
-            eq_dyn_null(arr_left.as_ref(), arr_right.as_ref(), null_equals_null)
+            eq_dyn_null(arr_left.as_ref(), arr_right.as_ref(), null_equality)
         })
         .try_fold(equal, |acc, equal2| and(&acc, &equal2?))?;
 
@@ -1496,7 +1492,7 @@ impl HashJoinStream {
             build_side.left_data.hash_map(),
             build_side.left_data.values(),
             &state.values,
-            self.null_equals_null,
+            self.null_equality,
             &self.hashes_buffer,
             self.batch_size,
             state.offset,
@@ -1726,7 +1722,7 @@ mod tests {
         right: Arc<dyn ExecutionPlan>,
         on: JoinOn,
         join_type: &JoinType,
-        null_equals_null: bool,
+        null_equality: NullEquality,
     ) -> Result<HashJoinExec> {
         HashJoinExec::try_new(
             left,
@@ -1736,7 +1732,7 @@ mod tests {
             join_type,
             None,
             PartitionMode::CollectLeft,
-            null_equals_null,
+            null_equality,
         )
     }
 
@@ -1746,7 +1742,7 @@ mod tests {
         on: JoinOn,
         filter: JoinFilter,
         join_type: &JoinType,
-        null_equals_null: bool,
+        null_equality: NullEquality,
     ) -> Result<HashJoinExec> {
         HashJoinExec::try_new(
             left,
@@ -1756,7 +1752,7 @@ mod tests {
             join_type,
             None,
             PartitionMode::CollectLeft,
-            null_equals_null,
+            null_equality,
         )
     }
 
@@ -1765,10 +1761,10 @@ mod tests {
         right: Arc<dyn ExecutionPlan>,
         on: JoinOn,
         join_type: &JoinType,
-        null_equals_null: bool,
+        null_equality: NullEquality,
         context: Arc<TaskContext>,
     ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
-        let join = join(left, right, on, join_type, null_equals_null)?;
+        let join = join(left, right, on, join_type, null_equality)?;
         let columns_header = columns(&join.schema());
 
         let stream = join.execute(0, context)?;
@@ -1782,7 +1778,7 @@ mod tests {
         right: Arc<dyn ExecutionPlan>,
         on: JoinOn,
         join_type: &JoinType,
-        null_equals_null: bool,
+        null_equality: NullEquality,
         context: Arc<TaskContext>,
     ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
         join_collect_with_partition_mode(
@@ -1791,7 +1787,7 @@ mod tests {
             on,
             join_type,
             PartitionMode::Partitioned,
-            null_equals_null,
+            null_equality,
             context,
         )
         .await
@@ -1803,7 +1799,7 @@ mod tests {
         on: JoinOn,
         join_type: &JoinType,
         partition_mode: PartitionMode,
-        null_equals_null: bool,
+        null_equality: NullEquality,
         context: Arc<TaskContext>,
     ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
         let partition_count = 4;
@@ -1853,7 +1849,7 @@ mod tests {
             join_type,
             None,
             partition_mode,
-            null_equals_null,
+            null_equality,
         )?;
 
         let columns = columns(&join.schema());
@@ -1898,7 +1894,7 @@ mod tests {
             Arc::clone(&right),
             on.clone(),
             &JoinType::Inner,
-            false,
+            NullEquality::NullEqualsNothing,
             task_ctx,
         )
         .await?;
@@ -1945,7 +1941,7 @@ mod tests {
             Arc::clone(&right),
             on.clone(),
             &JoinType::Inner,
-            false,
+            NullEquality::NullEqualsNothing,
             task_ctx,
         )
         .await?;
@@ -1985,8 +1981,15 @@ mod tests {
             Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
         )];
 
-        let (columns, batches) =
-            join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?;
+        let (columns, batches) = join_collect(
+            left,
+            right,
+            on,
+            &JoinType::Inner,
+            NullEquality::NullEqualsNothing,
+            task_ctx,
+        )
+        .await?;
 
         assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
 
@@ -2024,8 +2027,15 @@ mod tests {
             Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
         )];
 
-        let (columns, batches) =
-            join_collect(left, right, on, &JoinType::Inner, false, 
task_ctx).await?;
+        let (columns, batches) = join_collect(
+            left,
+            right,
+            on,
+            &JoinType::Inner,
+            NullEquality::NullEqualsNothing,
+            task_ctx,
+        )
+        .await?;
 
         assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
 
@@ -2071,8 +2081,15 @@ mod tests {
             ),
         ];
 
-        let (columns, batches) =
-            join_collect(left, right, on, &JoinType::Inner, false, 
task_ctx).await?;
+        let (columns, batches) = join_collect(
+            left,
+            right,
+            on,
+            &JoinType::Inner,
+            NullEquality::NullEqualsNothing,
+            task_ctx,
+        )
+        .await?;
 
         assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]);
 
@@ -2142,8 +2159,15 @@ mod tests {
             ),
         ];
 
-        let (columns, batches) =
-            join_collect(left, right, on, &JoinType::Inner, false, 
task_ctx).await?;
+        let (columns, batches) = join_collect(
+            left,
+            right,
+            on,
+            &JoinType::Inner,
+            NullEquality::NullEqualsNothing,
+            task_ctx,
+        )
+        .await?;
 
         assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]);
 
@@ -2208,8 +2232,15 @@ mod tests {
             Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
         )];
 
-        let (columns, batches) =
-            join_collect(left, right, on, &JoinType::Inner, false, 
task_ctx).await?;
+        let (columns, batches) = join_collect(
+            left,
+            right,
+            on,
+            &JoinType::Inner,
+            NullEquality::NullEqualsNothing,
+            task_ctx,
+        )
+        .await?;
 
         assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
 
@@ -2258,7 +2289,13 @@ mod tests {
             Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
         )];
 
-        let join = join(left, right, on, &JoinType::Inner, false)?;
+        let join = join(
+            left,
+            right,
+            on,
+            &JoinType::Inner,
+            NullEquality::NullEqualsNothing,
+        )?;
 
         let columns = columns(&join.schema());
         assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
@@ -2351,7 +2388,14 @@ mod tests {
             Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap()) 
as _,
         )];
 
-        let join = join(left, right, on, &JoinType::Left, false).unwrap();
+        let join = join(
+            left,
+            right,
+            on,
+            &JoinType::Left,
+            NullEquality::NullEqualsNothing,
+        )
+        .unwrap();
 
         let columns = columns(&join.schema());
         assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
@@ -2394,7 +2438,14 @@ mod tests {
             Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) 
as _,
         )];
 
-        let join = join(left, right, on, &JoinType::Full, false).unwrap();
+        let join = join(
+            left,
+            right,
+            on,
+            &JoinType::Full,
+            NullEquality::NullEqualsNothing,
+        )
+        .unwrap();
 
         let columns = columns(&join.schema());
         assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
@@ -2435,7 +2486,14 @@ mod tests {
         )];
         let schema = right.schema();
         let right = TestMemoryExec::try_new_exec(&[vec![right]], schema, 
None).unwrap();
-        let join = join(left, right, on, &JoinType::Left, false).unwrap();
+        let join = join(
+            left,
+            right,
+            on,
+            &JoinType::Left,
+            NullEquality::NullEqualsNothing,
+        )
+        .unwrap();
 
         let columns = columns(&join.schema());
         assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
@@ -2472,7 +2530,14 @@ mod tests {
         )];
         let schema = right.schema();
         let right = TestMemoryExec::try_new_exec(&[vec![right]], schema, 
None).unwrap();
-        let join = join(left, right, on, &JoinType::Full, false).unwrap();
+        let join = join(
+            left,
+            right,
+            on,
+            &JoinType::Full,
+            NullEquality::NullEqualsNothing,
+        )
+        .unwrap();
 
         let columns = columns(&join.schema());
         assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
@@ -2517,7 +2582,7 @@ mod tests {
             Arc::clone(&right),
             on.clone(),
             &JoinType::Left,
-            false,
+            NullEquality::NullEqualsNothing,
             task_ctx,
         )
         .await?;
@@ -2562,7 +2627,7 @@ mod tests {
             Arc::clone(&right),
             on.clone(),
             &JoinType::Left,
-            false,
+            NullEquality::NullEqualsNothing,
             task_ctx,
         )
         .await?;
@@ -2615,7 +2680,13 @@ mod tests {
             Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
         )];
 
-        let join = join(left, right, on, &JoinType::LeftSemi, false)?;
+        let join = join(
+            left,
+            right,
+            on,
+            &JoinType::LeftSemi,
+            NullEquality::NullEqualsNothing,
+        )?;
 
         let columns = columns(&join.schema());
         assert_eq!(columns, vec!["a1", "b1", "c1"]);
@@ -2677,7 +2748,7 @@ mod tests {
             on.clone(),
             filter,
             &JoinType::LeftSemi,
-            false,
+            NullEquality::NullEqualsNothing,
         )?;
 
         let columns_header = columns(&join.schema());
@@ -2710,7 +2781,14 @@ mod tests {
             Arc::new(intermediate_schema),
         );
 
-        let join = join_with_filter(left, right, on, filter, 
&JoinType::LeftSemi, false)?;
+        let join = join_with_filter(
+            left,
+            right,
+            on,
+            filter,
+            &JoinType::LeftSemi,
+            NullEquality::NullEqualsNothing,
+        )?;
 
         let columns_header = columns(&join.schema());
         assert_eq!(columns_header, vec!["a1", "b1", "c1"]);
@@ -2744,7 +2822,13 @@ mod tests {
             Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
         )];
 
-        let join = join(left, right, on, &JoinType::RightSemi, false)?;
+        let join = join(
+            left,
+            right,
+            on,
+            &JoinType::RightSemi,
+            NullEquality::NullEqualsNothing,
+        )?;
 
         let columns = columns(&join.schema());
         assert_eq!(columns, vec!["a2", "b2", "c2"]);
@@ -2806,7 +2890,7 @@ mod tests {
             on.clone(),
             filter,
             &JoinType::RightSemi,
-            false,
+            NullEquality::NullEqualsNothing,
         )?;
 
         let columns = columns(&join.schema());
@@ -2841,8 +2925,14 @@ mod tests {
             Arc::new(intermediate_schema.clone()),
         );
 
-        let join =
-            join_with_filter(left, right, on, filter, &JoinType::RightSemi, 
false)?;
+        let join = join_with_filter(
+            left,
+            right,
+            on,
+            filter,
+            &JoinType::RightSemi,
+            NullEquality::NullEqualsNothing,
+        )?;
         let stream = join.execute(0, task_ctx)?;
         let batches = common::collect(stream).await?;
 
@@ -2873,7 +2963,13 @@ mod tests {
             Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
         )];
 
-        let join = join(left, right, on, &JoinType::LeftAnti, false)?;
+        let join = join(
+            left,
+            right,
+            on,
+            &JoinType::LeftAnti,
+            NullEquality::NullEqualsNothing,
+        )?;
 
         let columns = columns(&join.schema());
         assert_eq!(columns, vec!["a1", "b1", "c1"]);
@@ -2932,7 +3028,7 @@ mod tests {
             on.clone(),
             filter,
             &JoinType::LeftAnti,
-            false,
+            NullEquality::NullEqualsNothing,
         )?;
 
         let columns_header = columns(&join.schema());
@@ -2969,7 +3065,14 @@ mod tests {
             Arc::new(intermediate_schema),
         );
 
-        let join = join_with_filter(left, right, on, filter, 
&JoinType::LeftAnti, false)?;
+        let join = join_with_filter(
+            left,
+            right,
+            on,
+            filter,
+            &JoinType::LeftAnti,
+            NullEquality::NullEqualsNothing,
+        )?;
 
         let columns_header = columns(&join.schema());
         assert_eq!(columns_header, vec!["a1", "b1", "c1"]);
@@ -3006,7 +3109,13 @@ mod tests {
             Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
         )];
 
-        let join = join(left, right, on, &JoinType::RightAnti, false)?;
+        let join = join(
+            left,
+            right,
+            on,
+            &JoinType::RightAnti,
+            NullEquality::NullEqualsNothing,
+        )?;
 
         let columns = columns(&join.schema());
         assert_eq!(columns, vec!["a2", "b2", "c2"]);
@@ -3066,7 +3175,7 @@ mod tests {
             on.clone(),
             filter,
             &JoinType::RightAnti,
-            false,
+            NullEquality::NullEqualsNothing,
         )?;
 
         let columns_header = columns(&join.schema());
@@ -3107,8 +3216,14 @@ mod tests {
             Arc::new(intermediate_schema),
         );
 
-        let join =
-            join_with_filter(left, right, on, filter, &JoinType::RightAnti, 
false)?;
+        let join = join_with_filter(
+            left,
+            right,
+            on,
+            filter,
+            &JoinType::RightAnti,
+            NullEquality::NullEqualsNothing,
+        )?;
 
         let columns_header = columns(&join.schema());
         assert_eq!(columns_header, vec!["a2", "b2", "c2"]);
@@ -3152,8 +3267,15 @@ mod tests {
             Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
         )];
 
-        let (columns, batches) =
-            join_collect(left, right, on, &JoinType::Right, false, 
task_ctx).await?;
+        let (columns, batches) = join_collect(
+            left,
+            right,
+            on,
+            &JoinType::Right,
+            NullEquality::NullEqualsNothing,
+            task_ctx,
+        )
+        .await?;
 
         assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
 
@@ -3191,9 +3313,15 @@ mod tests {
             Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
         )];
 
-        let (columns, batches) =
-            partitioned_join_collect(left, right, on, &JoinType::Right, false, 
task_ctx)
-                .await?;
+        let (columns, batches) = partitioned_join_collect(
+            left,
+            right,
+            on,
+            &JoinType::Right,
+            NullEquality::NullEqualsNothing,
+            task_ctx,
+        )
+        .await?;
 
         assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
 
@@ -3231,7 +3359,13 @@ mod tests {
             Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) 
as _,
         )];
 
-        let join = join(left, right, on, &JoinType::Full, false)?;
+        let join = join(
+            left,
+            right,
+            on,
+            &JoinType::Full,
+            NullEquality::NullEqualsNothing,
+        )?;
 
         let columns = columns(&join.schema());
         assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
@@ -3279,7 +3413,7 @@ mod tests {
             Arc::clone(&right),
             on.clone(),
             &JoinType::LeftMark,
-            false,
+            NullEquality::NullEqualsNothing,
             task_ctx,
         )
         .await?;
@@ -3324,7 +3458,7 @@ mod tests {
             Arc::clone(&right),
             on.clone(),
             &JoinType::LeftMark,
-            false,
+            NullEquality::NullEqualsNothing,
             task_ctx,
         )
         .await?;
@@ -3488,7 +3622,7 @@ mod tests {
             &join_hash_map,
             &[left_keys_values],
             &[right_keys_values],
-            false,
+            NullEquality::NullEqualsNothing,
             &hashes_buffer,
             8192,
             (0, None),
@@ -3524,7 +3658,13 @@ mod tests {
             Arc::new(Column::new_with_schema("b", &right.schema()).unwrap()) 
as _,
         )];
 
-        let join = join(left, right, on, &JoinType::Inner, false)?;
+        let join = join(
+            left,
+            right,
+            on,
+            &JoinType::Inner,
+            NullEquality::NullEqualsNothing,
+        )?;
 
         let columns = columns(&join.schema());
         assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]);
@@ -3594,7 +3734,14 @@ mod tests {
         )];
         let filter = prepare_join_filter();
 
-        let join = join_with_filter(left, right, on, filter, &JoinType::Inner, 
false)?;
+        let join = join_with_filter(
+            left,
+            right,
+            on,
+            filter,
+            &JoinType::Inner,
+            NullEquality::NullEqualsNothing,
+        )?;
 
         let columns = columns(&join.schema());
         assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]);
@@ -3636,7 +3783,14 @@ mod tests {
         )];
         let filter = prepare_join_filter();
 
-        let join = join_with_filter(left, right, on, filter, &JoinType::Left, 
false)?;
+        let join = join_with_filter(
+            left,
+            right,
+            on,
+            filter,
+            &JoinType::Left,
+            NullEquality::NullEqualsNothing,
+        )?;
 
         let columns = columns(&join.schema());
         assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]);
@@ -3681,7 +3835,14 @@ mod tests {
         )];
         let filter = prepare_join_filter();
 
-        let join = join_with_filter(left, right, on, filter, &JoinType::Right, 
false)?;
+        let join = join_with_filter(
+            left,
+            right,
+            on,
+            filter,
+            &JoinType::Right,
+            NullEquality::NullEqualsNothing,
+        )?;
 
         let columns = columns(&join.schema());
         assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]);
@@ -3725,7 +3886,14 @@ mod tests {
         )];
         let filter = prepare_join_filter();
 
-        let join = join_with_filter(left, right, on, filter, &JoinType::Full, 
false)?;
+        let join = join_with_filter(
+            left,
+            right,
+            on,
+            filter,
+            &JoinType::Full,
+            NullEquality::NullEqualsNothing,
+        )?;
 
         let columns = columns(&join.schema());
         assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]);
@@ -3892,7 +4060,7 @@ mod tests {
                 on.clone(),
                 &join_type,
                 PartitionMode::CollectLeft,
-                false,
+                NullEquality::NullEqualsNothing,
                 Arc::clone(&task_ctx),
             )
             .await?;
@@ -3924,7 +4092,13 @@ mod tests {
             Arc::new(Column::new_with_schema("date", 
&right.schema()).unwrap()) as _,
         )];
 
-        let join = join(left, right, on, &JoinType::Inner, false)?;
+        let join = join(
+            left,
+            right,
+            on,
+            &JoinType::Inner,
+            NullEquality::NullEqualsNothing,
+        )?;
 
         let task_ctx = Arc::new(TaskContext::default());
         let stream = join.execute(0, task_ctx)?;
@@ -3983,7 +4157,7 @@ mod tests {
                 Arc::clone(&right_input) as Arc<dyn ExecutionPlan>,
                 on.clone(),
                 &join_type,
-                false,
+                NullEquality::NullEqualsNothing,
             )
             .unwrap();
             let task_ctx = Arc::new(TaskContext::default());
@@ -4097,7 +4271,7 @@ mod tests {
                     Arc::clone(&right),
                     on.clone(),
                     &join_type,
-                    false,
+                    NullEquality::NullEqualsNothing,
                 )
                 .unwrap();
 
@@ -4177,7 +4351,7 @@ mod tests {
                 Arc::clone(&right),
                 on.clone(),
                 &join_type,
-                false,
+                NullEquality::NullEqualsNothing,
             )?;
 
             let stream = join.execute(0, task_ctx)?;
@@ -4258,7 +4432,7 @@ mod tests {
                 &join_type,
                 None,
                 PartitionMode::Partitioned,
-                false,
+                NullEquality::NullEqualsNothing,
             )?;
 
             let stream = join.execute(1, task_ctx)?;
@@ -4318,8 +4492,15 @@ mod tests {
             Arc::new(Column::new_with_schema("n2", &right.schema())?) as _,
         )];
 
-        let (columns, batches) =
-            join_collect(left, right, on, &JoinType::Inner, false, 
task_ctx).await?;
+        let (columns, batches) = join_collect(
+            left,
+            right,
+            on,
+            &JoinType::Inner,
+            NullEquality::NullEqualsNothing,
+            task_ctx,
+        )
+        .await?;
 
         assert_eq!(columns, vec!["n1", "n2"]);
 
@@ -4355,7 +4536,7 @@ mod tests {
             Arc::clone(&right),
             on.clone(),
             &JoinType::Inner,
-            true,
+            NullEquality::NullEqualsNull,
             Arc::clone(&task_ctx),
         )
         .await?;
@@ -4370,8 +4551,15 @@ mod tests {
                 "#);
         }
 
-        let (_, batches_null_neq) =
-            join_collect(left, right, on, &JoinType::Inner, false, 
task_ctx).await?;
+        let (_, batches_null_neq) = join_collect(
+            left,
+            right,
+            on,
+            &JoinType::Inner,
+            NullEquality::NullEqualsNothing,
+            task_ctx,
+        )
+        .await?;
 
         let expected_null_neq =
             ["+----+----+", "| n1 | n2 |", "+----+----+", "+----+----+"];
diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs
index 44021c38a7..fcc1107a0e 100644
--- a/datafusion/physical-plan/src/joins/nested_loop_join.rs
+++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs
@@ -720,7 +720,7 @@ struct NestedLoopJoinStream<T> {
     /// Information of index and left / right placement of columns
     column_indices: Vec<ColumnIndex>,
     // TODO: support null aware equal
-    // null_equals_null: bool
+    // null_equality: NullEquality,
     /// Join execution metrics
     join_metrics: BuildProbeJoinMetrics,
     /// Cache for join indices calculations
diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs
index 6ab069aaf4..4d635948ed 100644
--- a/datafusion/physical-plan/src/joins/sort_merge_join.rs
+++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs
@@ -63,7 +63,7 @@ use arrow::error::ArrowError;
 use arrow::ipc::reader::StreamReader;
 use datafusion_common::{
     exec_err, internal_err, not_impl_err, plan_err, DataFusionError, HashSet, 
JoinSide,
-    JoinType, Result,
+    JoinType, NullEquality, Result,
 };
 use datafusion_execution::disk_manager::RefCountedTempFile;
 use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
@@ -145,8 +145,8 @@ pub struct SortMergeJoinExec {
     right_sort_exprs: LexOrdering,
     /// Sort options of join columns used in sorting left and right execution plans
     pub sort_options: Vec<SortOptions>,
-    /// If null_equals_null is true, null == null else null != null
-    pub null_equals_null: bool,
+    /// Defines the null equality for the join.
+    pub null_equality: NullEquality,
     /// Cache holding plan properties like equivalences, output partitioning etc.
     cache: PlanProperties,
 }
@@ -163,7 +163,7 @@ impl SortMergeJoinExec {
         filter: Option<JoinFilter>,
         join_type: JoinType,
         sort_options: Vec<SortOptions>,
-        null_equals_null: bool,
+        null_equality: NullEquality,
     ) -> Result<Self> {
         let left_schema = left.schema();
         let right_schema = right.schema();
@@ -218,7 +218,7 @@ impl SortMergeJoinExec {
             left_sort_exprs,
             right_sort_exprs,
             sort_options,
-            null_equals_null,
+            null_equality,
             cache,
         })
     }
@@ -291,9 +291,9 @@ impl SortMergeJoinExec {
         &self.sort_options
     }
 
-    /// Null equals null
-    pub fn null_equals_null(&self) -> bool {
-        self.null_equals_null
+    /// Null equality
+    pub fn null_equality(&self) -> NullEquality {
+        self.null_equality
     }
 
     /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
@@ -339,7 +339,7 @@ impl SortMergeJoinExec {
             self.filter().as_ref().map(JoinFilter::swap),
             self.join_type().swap(),
             self.sort_options.clone(),
-            self.null_equals_null,
+            self.null_equality,
         )?;
 
         // TODO: OR this condition with having a built-in projection (like
@@ -450,7 +450,7 @@ impl ExecutionPlan for SortMergeJoinExec {
                 self.filter.clone(),
                 self.join_type,
                 self.sort_options.clone(),
-                self.null_equals_null,
+                self.null_equality,
             )?)),
             _ => internal_err!("SortMergeJoin wrong number of children"),
         }
@@ -502,7 +502,7 @@ impl ExecutionPlan for SortMergeJoinExec {
         Ok(Box::pin(SortMergeJoinStream::try_new(
             Arc::clone(&self.schema),
             self.sort_options.clone(),
-            self.null_equals_null,
+            self.null_equality,
             streamed,
             buffered,
             on_streamed,
@@ -591,7 +591,7 @@ impl ExecutionPlan for SortMergeJoinExec {
             self.filter.clone(),
             self.join_type,
             self.sort_options.clone(),
-            self.null_equals_null,
+            self.null_equality,
         )?)))
     }
 }
@@ -844,8 +844,8 @@ struct SortMergeJoinStream {
     // ========================================================================
     /// Output schema
     pub schema: SchemaRef,
-    /// null == null?
-    pub null_equals_null: bool,
+    /// Defines the null equality for the join.
+    pub null_equality: NullEquality,
     /// Sort options of join columns used to sort streamed and buffered data stream
     pub sort_options: Vec<SortOptions>,
     /// optional join filter
@@ -1326,7 +1326,7 @@ impl SortMergeJoinStream {
     pub fn try_new(
         schema: SchemaRef,
         sort_options: Vec<SortOptions>,
-        null_equals_null: bool,
+        null_equality: NullEquality,
         streamed: SendableRecordBatchStream,
         buffered: SendableRecordBatchStream,
         on_streamed: Vec<Arc<dyn PhysicalExpr>>,
@@ -1348,7 +1348,7 @@ impl SortMergeJoinStream {
         Ok(Self {
             state: SortMergeJoinState::Init,
             sort_options,
-            null_equals_null,
+            null_equality,
             schema: Arc::clone(&schema),
             streamed_schema: Arc::clone(&streamed_schema),
             buffered_schema,
@@ -1593,7 +1593,7 @@ impl SortMergeJoinStream {
             &self.buffered_data.head_batch().join_arrays,
             self.buffered_data.head_batch().range.start,
             &self.sort_options,
-            self.null_equals_null,
+            self.null_equality,
         )
     }
 
@@ -2434,7 +2434,7 @@ fn compare_join_arrays(
     right_arrays: &[ArrayRef],
     right: usize,
     sort_options: &[SortOptions],
-    null_equals_null: bool,
+    null_equality: NullEquality,
 ) -> Result<Ordering> {
     let mut res = Ordering::Equal;
     for ((left_array, right_array), sort_options) in
@@ -2468,10 +2468,9 @@ fn compare_join_arrays(
                         };
                     }
                     _ => {
-                        res = if null_equals_null {
-                            Ordering::Equal
-                        } else {
-                            Ordering::Less
+                        res = match null_equality {
+                            NullEquality::NullEqualsNothing => Ordering::Less,
+                            NullEquality::NullEqualsNull => Ordering::Equal,
                         };
                     }
                 }
@@ -2597,7 +2596,9 @@ mod tests {
     use arrow::datatypes::{DataType, Field, Schema};
 
     use datafusion_common::JoinType::*;
-    use datafusion_common::{assert_batches_eq, assert_contains, JoinType, 
Result};
+    use datafusion_common::{
+        assert_batches_eq, assert_contains, JoinType, NullEquality, Result,
+    };
     use datafusion_common::{
         test_util::{batches_to_sort_string, batches_to_string},
         JoinSide,
@@ -2722,7 +2723,15 @@ mod tests {
         join_type: JoinType,
     ) -> Result<SortMergeJoinExec> {
         let sort_options = vec![SortOptions::default(); on.len()];
-        SortMergeJoinExec::try_new(left, right, on, None, join_type, 
sort_options, false)
+        SortMergeJoinExec::try_new(
+            left,
+            right,
+            on,
+            None,
+            join_type,
+            sort_options,
+            NullEquality::NullEqualsNothing,
+        )
     }
 
     fn join_with_options(
@@ -2731,7 +2740,7 @@ mod tests {
         on: JoinOn,
         join_type: JoinType,
         sort_options: Vec<SortOptions>,
-        null_equals_null: bool,
+        null_equality: NullEquality,
     ) -> Result<SortMergeJoinExec> {
         SortMergeJoinExec::try_new(
             left,
@@ -2740,7 +2749,7 @@ mod tests {
             None,
             join_type,
             sort_options,
-            null_equals_null,
+            null_equality,
         )
     }
 
@@ -2751,7 +2760,7 @@ mod tests {
         filter: JoinFilter,
         join_type: JoinType,
         sort_options: Vec<SortOptions>,
-        null_equals_null: bool,
+        null_equality: NullEquality,
     ) -> Result<SortMergeJoinExec> {
         SortMergeJoinExec::try_new(
             left,
@@ -2760,7 +2769,7 @@ mod tests {
             Some(filter),
             join_type,
             sort_options,
-            null_equals_null,
+            null_equality,
         )
     }
 
@@ -2771,7 +2780,15 @@ mod tests {
         join_type: JoinType,
     ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
         let sort_options = vec![SortOptions::default(); on.len()];
-        join_collect_with_options(left, right, on, join_type, sort_options, 
false).await
+        join_collect_with_options(
+            left,
+            right,
+            on,
+            join_type,
+            sort_options,
+            NullEquality::NullEqualsNothing,
+        )
+        .await
     }
 
     async fn join_collect_with_filter(
@@ -2784,8 +2801,15 @@ mod tests {
         let sort_options = vec![SortOptions::default(); on.len()];
 
         let task_ctx = Arc::new(TaskContext::default());
-        let join =
-            join_with_filter(left, right, on, filter, join_type, sort_options, 
false)?;
+        let join = join_with_filter(
+            left,
+            right,
+            on,
+            filter,
+            join_type,
+            sort_options,
+            NullEquality::NullEqualsNothing,
+        )?;
         let columns = columns(&join.schema());
 
         let stream = join.execute(0, task_ctx)?;
@@ -2799,17 +2823,11 @@ mod tests {
         on: JoinOn,
         join_type: JoinType,
         sort_options: Vec<SortOptions>,
-        null_equals_null: bool,
+        null_equality: NullEquality,
     ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
         let task_ctx = Arc::new(TaskContext::default());
-        let join = join_with_options(
-            left,
-            right,
-            on,
-            join_type,
-            sort_options,
-            null_equals_null,
-        )?;
+        let join =
+            join_with_options(left, right, on, join_type, sort_options, 
null_equality)?;
         let columns = columns(&join.schema());
 
         let stream = join.execute(0, task_ctx)?;
@@ -3015,7 +3033,7 @@ mod tests {
                 };
                 2
             ],
-            true,
+            NullEquality::NullEqualsNull,
         )
         .await?;
         // The output order is important as SMJ preserves sortedness
@@ -3438,7 +3456,7 @@ mod tests {
                 };
                 2
             ],
-            true,
+            NullEquality::NullEqualsNull,
         )
         .await?;
 
@@ -3715,7 +3733,7 @@ mod tests {
                 };
                 2
             ],
-            true,
+            NullEquality::NullEqualsNull,
         )
         .await?;
 
@@ -4159,7 +4177,7 @@ mod tests {
                 on.clone(),
                 join_type,
                 sort_options.clone(),
-                false,
+                NullEquality::NullEqualsNothing,
             )?;
 
             let stream = join.execute(0, task_ctx)?;
@@ -4240,7 +4258,7 @@ mod tests {
                 on.clone(),
                 join_type,
                 sort_options.clone(),
-                false,
+                NullEquality::NullEqualsNothing,
             )?;
 
             let stream = join.execute(0, task_ctx)?;
@@ -4303,7 +4321,7 @@ mod tests {
                     on.clone(),
                     *join_type,
                     sort_options.clone(),
-                    false,
+                    NullEquality::NullEqualsNothing,
                 )?;
 
                 let stream = join.execute(0, task_ctx)?;
@@ -4325,7 +4343,7 @@ mod tests {
                     on.clone(),
                     *join_type,
                     sort_options.clone(),
-                    false,
+                    NullEquality::NullEqualsNothing,
                 )?;
                 let stream = join.execute(0, task_ctx_no_spill)?;
                 let no_spilled_join_result = 
common::collect(stream).await.unwrap();
@@ -4407,7 +4425,7 @@ mod tests {
                     on.clone(),
                     *join_type,
                     sort_options.clone(),
-                    false,
+                    NullEquality::NullEqualsNothing,
                 )?;
 
                 let stream = join.execute(0, task_ctx)?;
@@ -4428,7 +4446,7 @@ mod tests {
                     on.clone(),
                     *join_type,
                     sort_options.clone(),
-                    false,
+                    NullEquality::NullEqualsNothing,
                 )?;
                 let stream = join.execute(0, task_ctx_no_spill)?;
                 let no_spilled_join_result = 
common::collect(stream).await.unwrap();
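
As a quick illustration of the sort-merge change above: when both join keys are NULL, the comparison now follows the enum rather than a boolean. A minimal, standalone sketch of that branch (mirroring compare_join_arrays; not part of the patch):

    use std::cmp::Ordering;
    use datafusion_common::NullEquality;

    // NullEqualsNull lets two NULL keys match; NullEqualsNothing orders them
    // as Less so they never join.
    fn null_vs_null(null_equality: NullEquality) -> Ordering {
        match null_equality {
            NullEquality::NullEqualsNothing => Ordering::Less,
            NullEquality::NullEqualsNull => Ordering::Equal,
        }
    }

    fn main() {
        assert_eq!(null_vs_null(NullEquality::NullEqualsNull), Ordering::Equal);
        assert_eq!(null_vs_null(NullEquality::NullEqualsNothing), Ordering::Less);
    }
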
diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
index 84575acea5..d540b6d2a3 100644
--- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
+++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
@@ -67,7 +67,9 @@ use arrow::datatypes::{ArrowNativeType, Schema, SchemaRef};
 use arrow::record_batch::RecordBatch;
 use datafusion_common::hash_utils::create_hashes;
 use datafusion_common::utils::bisect;
-use datafusion_common::{internal_err, plan_err, HashSet, JoinSide, JoinType, 
Result};
+use datafusion_common::{
+    internal_err, plan_err, HashSet, JoinSide, JoinType, NullEquality, Result,
+};
 use datafusion_execution::memory_pool::MemoryConsumer;
 use datafusion_execution::TaskContext;
 use datafusion_expr::interval_arithmetic::Interval;
@@ -185,8 +187,8 @@ pub struct SymmetricHashJoinExec {
     metrics: ExecutionPlanMetricsSet,
     /// Information of index and left / right placement of columns
     column_indices: Vec<ColumnIndex>,
-    /// If null_equals_null is true, null == null else null != null
-    pub(crate) null_equals_null: bool,
+    /// Defines the null equality for the join.
+    pub(crate) null_equality: NullEquality,
     /// Left side sort expression(s)
     pub(crate) left_sort_exprs: Option<LexOrdering>,
     /// Right side sort expression(s)
@@ -211,7 +213,7 @@ impl SymmetricHashJoinExec {
         on: JoinOn,
         filter: Option<JoinFilter>,
         join_type: &JoinType,
-        null_equals_null: bool,
+        null_equality: NullEquality,
         left_sort_exprs: Option<LexOrdering>,
         right_sort_exprs: Option<LexOrdering>,
         mode: StreamJoinPartitionMode,
@@ -246,7 +248,7 @@ impl SymmetricHashJoinExec {
             random_state,
             metrics: ExecutionPlanMetricsSet::new(),
             column_indices,
-            null_equals_null,
+            null_equality,
             left_sort_exprs,
             right_sort_exprs,
             mode,
@@ -310,9 +312,9 @@ impl SymmetricHashJoinExec {
         &self.join_type
     }
 
-    /// Get null_equals_null
-    pub fn null_equals_null(&self) -> bool {
-        self.null_equals_null
+    /// Get null_equality
+    pub fn null_equality(&self) -> NullEquality {
+        self.null_equality
     }
 
     /// Get partition mode
@@ -456,7 +458,7 @@ impl ExecutionPlan for SymmetricHashJoinExec {
             self.on.clone(),
             self.filter.clone(),
             &self.join_type,
-            self.null_equals_null,
+            self.null_equality,
             self.left_sort_exprs.clone(),
             self.right_sort_exprs.clone(),
             self.mode,
@@ -545,7 +547,7 @@ impl ExecutionPlan for SymmetricHashJoinExec {
                 graph,
                 left_sorted_filter_expr,
                 right_sorted_filter_expr,
-                null_equals_null: self.null_equals_null,
+                null_equality: self.null_equality,
                 state: SHJStreamState::PullRight,
                 reservation,
                 batch_transformer: BatchSplitter::new(batch_size),
@@ -565,7 +567,7 @@ impl ExecutionPlan for SymmetricHashJoinExec {
                 graph,
                 left_sorted_filter_expr,
                 right_sorted_filter_expr,
-                null_equals_null: self.null_equals_null,
+                null_equality: self.null_equality,
                 state: SHJStreamState::PullRight,
                 reservation,
                 batch_transformer: NoopBatchTransformer::new(),
@@ -637,7 +639,7 @@ impl ExecutionPlan for SymmetricHashJoinExec {
             new_on,
             new_filter,
             self.join_type(),
-            self.null_equals_null(),
+            self.null_equality(),
             self.right().output_ordering().cloned(),
             self.left().output_ordering().cloned(),
             self.partition_mode(),
@@ -671,8 +673,8 @@ struct SymmetricHashJoinStream<T> {
     right_sorted_filter_expr: Option<SortedFilterExpr>,
     /// Random state used for hashing initialization
     random_state: RandomState,
-    /// If null_equals_null is true, null == null else null != null
-    null_equals_null: bool,
+    /// Defines the null equality for the join.
+    null_equality: NullEquality,
     /// Metrics
     metrics: StreamJoinMetrics,
     /// Memory reservation
@@ -934,7 +936,7 @@ pub(crate) fn build_side_determined_results(
 /// * `probe_batch` - The second record batch to be joined.
 /// * `column_indices` - An array of columns to be selected for the result of the join.
 /// * `random_state` - The random state for the join.
-/// * `null_equals_null` - A boolean indicating whether NULL values should be treated as equal when joining.
+/// * `null_equality` - Indicates whether NULL values should be treated as equal when joining.
 ///
 /// # Returns
 ///
@@ -950,7 +952,7 @@ pub(crate) fn join_with_probe_batch(
     probe_batch: &RecordBatch,
     column_indices: &[ColumnIndex],
     random_state: &RandomState,
-    null_equals_null: bool,
+    null_equality: NullEquality,
 ) -> Result<Option<RecordBatch>> {
     if build_hash_joiner.input_buffer.num_rows() == 0 || 
probe_batch.num_rows() == 0 {
         return Ok(None);
@@ -962,7 +964,7 @@ pub(crate) fn join_with_probe_batch(
         &build_hash_joiner.on,
         &probe_hash_joiner.on,
         random_state,
-        null_equals_null,
+        null_equality,
         &mut build_hash_joiner.hashes_buffer,
         Some(build_hash_joiner.deleted_offset),
     )?;
@@ -1028,7 +1030,7 @@ pub(crate) fn join_with_probe_batch(
 /// * `build_on` - An array of columns on which the join will be performed. The columns are from the build side of the join.
 /// * `probe_on` - An array of columns on which the join will be performed. The columns are from the probe side of the join.
 /// * `random_state` - The random state for the join.
-/// * `null_equals_null` - A boolean indicating whether NULL values should be treated as equal when joining.
+/// * `null_equality` - Indicates whether NULL values should be treated as equal when joining.
 /// * `hashes_buffer` - Buffer used for probe side keys hash calculation.
 /// * `deleted_offset` - deleted offset for build side data.
 ///
@@ -1044,7 +1046,7 @@ fn lookup_join_hashmap(
     build_on: &[PhysicalExprRef],
     probe_on: &[PhysicalExprRef],
     random_state: &RandomState,
-    null_equals_null: bool,
+    null_equality: NullEquality,
     hashes_buffer: &mut Vec<u64>,
     deleted_offset: Option<usize>,
 ) -> Result<(UInt64Array, UInt32Array)> {
@@ -1105,7 +1107,7 @@ fn lookup_join_hashmap(
         &probe_indices,
         &build_join_values,
         &keys_values,
-        null_equals_null,
+        null_equality,
     )?;
 
     Ok((build_indices, probe_indices))
@@ -1602,7 +1604,7 @@ impl<T: BatchTransformer> SymmetricHashJoinStream<T> {
         size += size_of_val(&self.left_sorted_filter_expr);
         size += size_of_val(&self.right_sorted_filter_expr);
         size += size_of_val(&self.random_state);
-        size += size_of_val(&self.null_equals_null);
+        size += size_of_val(&self.null_equality);
         size += size_of_val(&self.metrics);
         size
     }
@@ -1657,7 +1659,7 @@ impl<T: BatchTransformer> SymmetricHashJoinStream<T> {
             &probe_batch,
             &self.column_indices,
             &self.random_state,
-            self.null_equals_null,
+            self.null_equality,
         )?;
         // Increment the offset for the probe hash joiner:
         probe_hash_joiner.offset += probe_batch.num_rows();
@@ -1813,12 +1815,18 @@ mod tests {
             on.clone(),
             filter.clone(),
             &join_type,
-            false,
+            NullEquality::NullEqualsNothing,
             Arc::clone(&task_ctx),
         )
         .await?;
         let second_batches = partitioned_hash_join_with_filter(
-            left, right, on, filter, &join_type, false, task_ctx,
+            left,
+            right,
+            on,
+            filter,
+            &join_type,
+            NullEquality::NullEqualsNothing,
+            task_ctx,
         )
         .await?;
         compare_batches(&first_batches, &second_batches);
diff --git a/datafusion/physical-plan/src/joins/test_utils.rs b/datafusion/physical-plan/src/joins/test_utils.rs
index cbabd7cb45..ea893cc933 100644
--- a/datafusion/physical-plan/src/joins/test_utils.rs
+++ b/datafusion/physical-plan/src/joins/test_utils.rs
@@ -33,7 +33,7 @@ use arrow::array::{
 };
 use arrow::datatypes::{DataType, Schema};
 use arrow::util::pretty::pretty_format_batches;
-use datafusion_common::{Result, ScalarValue};
+use datafusion_common::{NullEquality, Result, ScalarValue};
 use datafusion_execution::TaskContext;
 use datafusion_expr::{JoinType, Operator};
 use datafusion_physical_expr::expressions::{binary, cast, col, lit};
@@ -74,7 +74,7 @@ pub async fn partitioned_sym_join_with_filter(
     on: JoinOn,
     filter: Option<JoinFilter>,
     join_type: &JoinType,
-    null_equals_null: bool,
+    null_equality: NullEquality,
     context: Arc<TaskContext>,
 ) -> Result<Vec<RecordBatch>> {
     let partition_count = 4;
@@ -101,7 +101,7 @@ pub async fn partitioned_sym_join_with_filter(
         on,
         filter,
         join_type,
-        null_equals_null,
+        null_equality,
         left.output_ordering().cloned(),
         right.output_ordering().cloned(),
         StreamJoinPartitionMode::Partitioned,
@@ -128,7 +128,7 @@ pub async fn partitioned_hash_join_with_filter(
     on: JoinOn,
     filter: Option<JoinFilter>,
     join_type: &JoinType,
-    null_equals_null: bool,
+    null_equality: NullEquality,
     context: Arc<TaskContext>,
 ) -> Result<Vec<RecordBatch>> {
     let partition_count = 4;
@@ -151,7 +151,7 @@ pub async fn partitioned_hash_join_with_filter(
         join_type,
         None,
         PartitionMode::Partitioned,
-        null_equals_null,
+        null_equality,
     )?);
 
     let mut batches = vec![];
diff --git a/datafusion/proto-common/proto/datafusion_common.proto b/datafusion/proto-common/proto/datafusion_common.proto
index 9eab33928a..81fc9cceb7 100644
--- a/datafusion/proto-common/proto/datafusion_common.proto
+++ b/datafusion/proto-common/proto/datafusion_common.proto
@@ -93,6 +93,11 @@ enum JoinConstraint {
   USING = 1;
 }
 
+enum NullEquality {
+  NULL_EQUALS_NOTHING = 0;
+  NULL_EQUALS_NULL = 1;
+}
+
 message AvroOptions {}
 message ArrowOptions {}
 
diff --git a/datafusion/proto-common/src/generated/pbjson.rs b/datafusion/proto-common/src/generated/pbjson.rs
index 0c593a36b8..c3b6686df0 100644
--- a/datafusion/proto-common/src/generated/pbjson.rs
+++ b/datafusion/proto-common/src/generated/pbjson.rs
@@ -4418,6 +4418,77 @@ impl<'de> serde::Deserialize<'de> for NdJsonFormat {
         deserializer.deserialize_struct("datafusion_common.NdJsonFormat", 
FIELDS, GeneratedVisitor)
     }
 }
+impl serde::Serialize for NullEquality {
+    #[allow(deprecated)]
+    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, 
S::Error>
+    where
+        S: serde::Serializer,
+    {
+        let variant = match self {
+            Self::NullEqualsNothing => "NULL_EQUALS_NOTHING",
+            Self::NullEqualsNull => "NULL_EQUALS_NULL",
+        };
+        serializer.serialize_str(variant)
+    }
+}
+impl<'de> serde::Deserialize<'de> for NullEquality {
+    #[allow(deprecated)]
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        const FIELDS: &[&str] = &[
+            "NULL_EQUALS_NOTHING",
+            "NULL_EQUALS_NULL",
+        ];
+
+        struct GeneratedVisitor;
+
+        impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+            type Value = NullEquality;
+
+            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> 
std::fmt::Result {
+                write!(formatter, "expected one of: {:?}", &FIELDS)
+            }
+
+            fn visit_i64<E>(self, v: i64) -> std::result::Result<Self::Value, 
E>
+            where
+                E: serde::de::Error,
+            {
+                i32::try_from(v)
+                    .ok()
+                    .and_then(|x| x.try_into().ok())
+                    .ok_or_else(|| {
+                        
serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self)
+                    })
+            }
+
+            fn visit_u64<E>(self, v: u64) -> std::result::Result<Self::Value, 
E>
+            where
+                E: serde::de::Error,
+            {
+                i32::try_from(v)
+                    .ok()
+                    .and_then(|x| x.try_into().ok())
+                    .ok_or_else(|| {
+                        
serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self)
+                    })
+            }
+
+            fn visit_str<E>(self, value: &str) -> 
std::result::Result<Self::Value, E>
+            where
+                E: serde::de::Error,
+            {
+                match value {
+                    "NULL_EQUALS_NOTHING" => 
Ok(NullEquality::NullEqualsNothing),
+                    "NULL_EQUALS_NULL" => Ok(NullEquality::NullEqualsNull),
+                    _ => Err(serde::de::Error::unknown_variant(value, FIELDS)),
+                }
+            }
+        }
+        deserializer.deserialize_any(GeneratedVisitor)
+    }
+}
 impl serde::Serialize for ParquetColumnOptions {
     #[allow(deprecated)]
     fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, 
S::Error>
diff --git a/datafusion/proto-common/src/generated/prost.rs b/datafusion/proto-common/src/generated/prost.rs
index c051dd00f7..411d72af4c 100644
--- a/datafusion/proto-common/src/generated/prost.rs
+++ b/datafusion/proto-common/src/generated/prost.rs
@@ -970,6 +970,32 @@ impl JoinConstraint {
 }
 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, 
::prost::Enumeration)]
 #[repr(i32)]
+pub enum NullEquality {
+    NullEqualsNothing = 0,
+    NullEqualsNull = 1,
+}
+impl NullEquality {
+    /// String value of the enum field names used in the ProtoBuf definition.
+    ///
+    /// The values are not transformed in any way and thus are considered stable
+    /// (if the ProtoBuf definition does not change) and safe for programmatic use.
+    pub fn as_str_name(&self) -> &'static str {
+        match self {
+            Self::NullEqualsNothing => "NULL_EQUALS_NOTHING",
+            Self::NullEqualsNull => "NULL_EQUALS_NULL",
+        }
+    }
+    /// Creates an enum from field names used in the ProtoBuf definition.
+    pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
+        match value {
+            "NULL_EQUALS_NOTHING" => Some(Self::NullEqualsNothing),
+            "NULL_EQUALS_NULL" => Some(Self::NullEqualsNull),
+            _ => None,
+        }
+    }
+}
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, 
::prost::Enumeration)]
+#[repr(i32)]
 pub enum TimeUnit {
     Second = 0,
     Millisecond = 1,
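
For context on the generated protobuf enum above, a short usage sketch of its helpers. The `datafusion_proto_common` import path is an assumption here; use whatever path re-exports the generated type in your build:

    use datafusion_proto_common::NullEquality; // assumed re-export path for the generated type

    fn main() {
        // String names match the ProtoBuf definition exactly.
        let eq = NullEquality::from_str_name("NULL_EQUALS_NULL").unwrap();
        assert_eq!(eq.as_str_name(), "NULL_EQUALS_NULL");
        // On the wire the enum is an i32; 0 (NULL_EQUALS_NOTHING) is the proto3
        // default, which preserves the behavior of the old `false` flag.
        assert_eq!(NullEquality::NullEqualsNothing as i32, 0);
    }
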
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 76cd8c9118..1e1f91e07e 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -243,7 +243,7 @@ message JoinNode {
   datafusion_common.JoinConstraint join_constraint = 4;
   repeated LogicalExprNode left_join_key = 5;
   repeated LogicalExprNode right_join_key = 6;
-  bool null_equals_null = 7;
+  datafusion_common.NullEquality null_equality = 7;
   LogicalExprNode filter = 8;
 }
 
@@ -1051,7 +1051,7 @@ message HashJoinExecNode {
   repeated JoinOn on = 3;
   datafusion_common.JoinType join_type = 4;
   PartitionMode partition_mode = 6;
-  bool null_equals_null = 7;
+  datafusion_common.NullEquality null_equality = 7;
   JoinFilter filter = 8;
   repeated uint32 projection = 9;
 }
@@ -1067,7 +1067,7 @@ message SymmetricHashJoinExecNode {
   repeated JoinOn on = 3;
   datafusion_common.JoinType join_type = 4;
   StreamPartitionMode partition_mode = 6;
-  bool null_equals_null = 7;
+  datafusion_common.NullEquality null_equality = 7;
   JoinFilter filter = 8;
   repeated PhysicalSortExprNode left_sort_exprs = 9;
   repeated PhysicalSortExprNode right_sort_exprs = 10;
diff --git a/datafusion/proto/src/generated/datafusion_proto_common.rs b/datafusion/proto/src/generated/datafusion_proto_common.rs
index c051dd00f7..411d72af4c 100644
--- a/datafusion/proto/src/generated/datafusion_proto_common.rs
+++ b/datafusion/proto/src/generated/datafusion_proto_common.rs
@@ -970,6 +970,32 @@ impl JoinConstraint {
 }
 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, 
::prost::Enumeration)]
 #[repr(i32)]
+pub enum NullEquality {
+    NullEqualsNothing = 0,
+    NullEqualsNull = 1,
+}
+impl NullEquality {
+    /// String value of the enum field names used in the ProtoBuf definition.
+    ///
+    /// The values are not transformed in any way and thus are considered stable
+    /// (if the ProtoBuf definition does not change) and safe for programmatic use.
+    pub fn as_str_name(&self) -> &'static str {
+        match self {
+            Self::NullEqualsNothing => "NULL_EQUALS_NOTHING",
+            Self::NullEqualsNull => "NULL_EQUALS_NULL",
+        }
+    }
+    /// Creates an enum from field names used in the ProtoBuf definition.
+    pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
+        match value {
+            "NULL_EQUALS_NOTHING" => Some(Self::NullEqualsNothing),
+            "NULL_EQUALS_NULL" => Some(Self::NullEqualsNull),
+            _ => None,
+        }
+    }
+}
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, 
::prost::Enumeration)]
+#[repr(i32)]
 pub enum TimeUnit {
     Second = 0,
     Millisecond = 1,
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 8a62be84ec..02a1cc70ee 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -6851,7 +6851,7 @@ impl serde::Serialize for HashJoinExecNode {
         if self.partition_mode != 0 {
             len += 1;
         }
-        if self.null_equals_null {
+        if self.null_equality != 0 {
             len += 1;
         }
         if self.filter.is_some() {
@@ -6880,8 +6880,10 @@ impl serde::Serialize for HashJoinExecNode {
                 .map_err(|_| serde::ser::Error::custom(format!("Invalid 
variant {}", self.partition_mode)))?;
             struct_ser.serialize_field("partitionMode", &v)?;
         }
-        if self.null_equals_null {
-            struct_ser.serialize_field("nullEqualsNull", 
&self.null_equals_null)?;
+        if self.null_equality != 0 {
+            let v = 
super::datafusion_common::NullEquality::try_from(self.null_equality)
+                .map_err(|_| serde::ser::Error::custom(format!("Invalid 
variant {}", self.null_equality)))?;
+            struct_ser.serialize_field("nullEquality", &v)?;
         }
         if let Some(v) = self.filter.as_ref() {
             struct_ser.serialize_field("filter", v)?;
@@ -6906,8 +6908,8 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode {
             "joinType",
             "partition_mode",
             "partitionMode",
-            "null_equals_null",
-            "nullEqualsNull",
+            "null_equality",
+            "nullEquality",
             "filter",
             "projection",
         ];
@@ -6919,7 +6921,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode {
             On,
             JoinType,
             PartitionMode,
-            NullEqualsNull,
+            NullEquality,
             Filter,
             Projection,
         }
@@ -6948,7 +6950,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode {
                             "on" => Ok(GeneratedField::On),
                             "joinType" | "join_type" => 
Ok(GeneratedField::JoinType),
                             "partitionMode" | "partition_mode" => 
Ok(GeneratedField::PartitionMode),
-                            "nullEqualsNull" | "null_equals_null" => 
Ok(GeneratedField::NullEqualsNull),
+                            "nullEquality" | "null_equality" => 
Ok(GeneratedField::NullEquality),
                             "filter" => Ok(GeneratedField::Filter),
                             "projection" => Ok(GeneratedField::Projection),
                             _ => Err(serde::de::Error::unknown_field(value, 
FIELDS)),
@@ -6975,7 +6977,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode {
                 let mut on__ = None;
                 let mut join_type__ = None;
                 let mut partition_mode__ = None;
-                let mut null_equals_null__ = None;
+                let mut null_equality__ = None;
                 let mut filter__ = None;
                 let mut projection__ = None;
                 while let Some(k) = map_.next_key()? {
@@ -7010,11 +7012,11 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode {
                             }
                             partition_mode__ = 
Some(map_.next_value::<PartitionMode>()? as i32);
                         }
-                        GeneratedField::NullEqualsNull => {
-                            if null_equals_null__.is_some() {
-                                return 
Err(serde::de::Error::duplicate_field("nullEqualsNull"));
+                        GeneratedField::NullEquality => {
+                            if null_equality__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("nullEquality"));
                             }
-                            null_equals_null__ = Some(map_.next_value()?);
+                            null_equality__ = 
Some(map_.next_value::<super::datafusion_common::NullEquality>()? as i32);
                         }
                         GeneratedField::Filter => {
                             if filter__.is_some() {
@@ -7039,7 +7041,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode {
                     on: on__.unwrap_or_default(),
                     join_type: join_type__.unwrap_or_default(),
                     partition_mode: partition_mode__.unwrap_or_default(),
-                    null_equals_null: null_equals_null__.unwrap_or_default(),
+                    null_equality: null_equality__.unwrap_or_default(),
                     filter: filter__,
                     projection: projection__.unwrap_or_default(),
                 })
@@ -8475,7 +8477,7 @@ impl serde::Serialize for JoinNode {
         if !self.right_join_key.is_empty() {
             len += 1;
         }
-        if self.null_equals_null {
+        if self.null_equality != 0 {
             len += 1;
         }
         if self.filter.is_some() {
@@ -8504,8 +8506,10 @@ impl serde::Serialize for JoinNode {
         if !self.right_join_key.is_empty() {
             struct_ser.serialize_field("rightJoinKey", &self.right_join_key)?;
         }
-        if self.null_equals_null {
-            struct_ser.serialize_field("nullEqualsNull", 
&self.null_equals_null)?;
+        if self.null_equality != 0 {
+            let v = 
super::datafusion_common::NullEquality::try_from(self.null_equality)
+                .map_err(|_| serde::ser::Error::custom(format!("Invalid 
variant {}", self.null_equality)))?;
+            struct_ser.serialize_field("nullEquality", &v)?;
         }
         if let Some(v) = self.filter.as_ref() {
             struct_ser.serialize_field("filter", v)?;
@@ -8530,8 +8534,8 @@ impl<'de> serde::Deserialize<'de> for JoinNode {
             "leftJoinKey",
             "right_join_key",
             "rightJoinKey",
-            "null_equals_null",
-            "nullEqualsNull",
+            "null_equality",
+            "nullEquality",
             "filter",
         ];
 
@@ -8543,7 +8547,7 @@ impl<'de> serde::Deserialize<'de> for JoinNode {
             JoinConstraint,
             LeftJoinKey,
             RightJoinKey,
-            NullEqualsNull,
+            NullEquality,
             Filter,
         }
         impl<'de> serde::Deserialize<'de> for GeneratedField {
@@ -8572,7 +8576,7 @@ impl<'de> serde::Deserialize<'de> for JoinNode {
                             "joinConstraint" | "join_constraint" => 
Ok(GeneratedField::JoinConstraint),
                             "leftJoinKey" | "left_join_key" => 
Ok(GeneratedField::LeftJoinKey),
                             "rightJoinKey" | "right_join_key" => 
Ok(GeneratedField::RightJoinKey),
-                            "nullEqualsNull" | "null_equals_null" => 
Ok(GeneratedField::NullEqualsNull),
+                            "nullEquality" | "null_equality" => 
Ok(GeneratedField::NullEquality),
                             "filter" => Ok(GeneratedField::Filter),
                             _ => Err(serde::de::Error::unknown_field(value, 
FIELDS)),
                         }
@@ -8599,7 +8603,7 @@ impl<'de> serde::Deserialize<'de> for JoinNode {
                 let mut join_constraint__ = None;
                 let mut left_join_key__ = None;
                 let mut right_join_key__ = None;
-                let mut null_equals_null__ = None;
+                let mut null_equality__ = None;
                 let mut filter__ = None;
                 while let Some(k) = map_.next_key()? {
                     match k {
@@ -8639,11 +8643,11 @@ impl<'de> serde::Deserialize<'de> for JoinNode {
                             }
                             right_join_key__ = Some(map_.next_value()?);
                         }
-                        GeneratedField::NullEqualsNull => {
-                            if null_equals_null__.is_some() {
-                                return Err(serde::de::Error::duplicate_field("nullEqualsNull"));
+                        GeneratedField::NullEquality => {
+                            if null_equality__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("nullEquality"));
                            }
-                            null_equals_null__ = Some(map_.next_value()?);
+                            null_equality__ = Some(map_.next_value::<super::datafusion_common::NullEquality>()? as i32);
                         }
                         GeneratedField::Filter => {
                             if filter__.is_some() {
@@ -8660,7 +8664,7 @@ impl<'de> serde::Deserialize<'de> for JoinNode {
                     join_constraint: join_constraint__.unwrap_or_default(),
                     left_join_key: left_join_key__.unwrap_or_default(),
                     right_join_key: right_join_key__.unwrap_or_default(),
-                    null_equals_null: null_equals_null__.unwrap_or_default(),
+                    null_equality: null_equality__.unwrap_or_default(),
                     filter: filter__,
                 })
             }
@@ -20041,7 +20045,7 @@ impl serde::Serialize for SymmetricHashJoinExecNode {
         if self.partition_mode != 0 {
             len += 1;
         }
-        if self.null_equals_null {
+        if self.null_equality != 0 {
             len += 1;
         }
         if self.filter.is_some() {
@@ -20073,8 +20077,10 @@ impl serde::Serialize for SymmetricHashJoinExecNode {
                .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.partition_mode)))?;
             struct_ser.serialize_field("partitionMode", &v)?;
         }
-        if self.null_equals_null {
-            struct_ser.serialize_field("nullEqualsNull", &self.null_equals_null)?;
+        if self.null_equality != 0 {
+            let v = super::datafusion_common::NullEquality::try_from(self.null_equality)
+                .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.null_equality)))?;
+            struct_ser.serialize_field("nullEquality", &v)?;
         }
         if let Some(v) = self.filter.as_ref() {
             struct_ser.serialize_field("filter", v)?;
@@ -20102,8 +20108,8 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode {
             "joinType",
             "partition_mode",
             "partitionMode",
-            "null_equals_null",
-            "nullEqualsNull",
+            "null_equality",
+            "nullEquality",
             "filter",
             "left_sort_exprs",
             "leftSortExprs",
@@ -20118,7 +20124,7 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode {
             On,
             JoinType,
             PartitionMode,
-            NullEqualsNull,
+            NullEquality,
             Filter,
             LeftSortExprs,
             RightSortExprs,
@@ -20148,7 +20154,7 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode {
                            "on" => Ok(GeneratedField::On),
                            "joinType" | "join_type" => Ok(GeneratedField::JoinType),
                            "partitionMode" | "partition_mode" => Ok(GeneratedField::PartitionMode),
-                            "nullEqualsNull" | "null_equals_null" => Ok(GeneratedField::NullEqualsNull),
+                            "nullEquality" | "null_equality" => Ok(GeneratedField::NullEquality),
                            "filter" => Ok(GeneratedField::Filter),
                            "leftSortExprs" | "left_sort_exprs" => Ok(GeneratedField::LeftSortExprs),
                            "rightSortExprs" | "right_sort_exprs" => Ok(GeneratedField::RightSortExprs),
@@ -20176,7 +20182,7 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode {
                 let mut on__ = None;
                 let mut join_type__ = None;
                 let mut partition_mode__ = None;
-                let mut null_equals_null__ = None;
+                let mut null_equality__ = None;
                 let mut filter__ = None;
                 let mut left_sort_exprs__ = None;
                 let mut right_sort_exprs__ = None;
@@ -20212,11 +20218,11 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode {
                            }
                            partition_mode__ = Some(map_.next_value::<StreamPartitionMode>()? as i32);
                         }
-                        GeneratedField::NullEqualsNull => {
-                            if null_equals_null__.is_some() {
-                                return Err(serde::de::Error::duplicate_field("nullEqualsNull"));
+                        GeneratedField::NullEquality => {
+                            if null_equality__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("nullEquality"));
                            }
-                            null_equals_null__ = Some(map_.next_value()?);
+                            null_equality__ = Some(map_.next_value::<super::datafusion_common::NullEquality>()? as i32);
                         }
                         GeneratedField::Filter => {
                             if filter__.is_some() {
@@ -20244,7 +20250,7 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode {
                     on: on__.unwrap_or_default(),
                     join_type: join_type__.unwrap_or_default(),
                     partition_mode: partition_mode__.unwrap_or_default(),
-                    null_equals_null: null_equals_null__.unwrap_or_default(),
+                    null_equality: null_equality__.unwrap_or_default(),
                     filter: filter__,
                     left_sort_exprs: left_sort_exprs__.unwrap_or_default(),
                     right_sort_exprs: right_sort_exprs__.unwrap_or_default(),
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 3e3a04051f..c1f8fa61f3 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -369,8 +369,8 @@ pub struct JoinNode {
     pub left_join_key: ::prost::alloc::vec::Vec<LogicalExprNode>,
     #[prost(message, repeated, tag = "6")]
     pub right_join_key: ::prost::alloc::vec::Vec<LogicalExprNode>,
-    #[prost(bool, tag = "7")]
-    pub null_equals_null: bool,
+    #[prost(enumeration = "super::datafusion_common::NullEquality", tag = "7")]
+    pub null_equality: i32,
     #[prost(message, optional, tag = "8")]
     pub filter: ::core::option::Option<LogicalExprNode>,
 }
@@ -1592,8 +1592,8 @@ pub struct HashJoinExecNode {
     pub join_type: i32,
     #[prost(enumeration = "PartitionMode", tag = "6")]
     pub partition_mode: i32,
-    #[prost(bool, tag = "7")]
-    pub null_equals_null: bool,
+    #[prost(enumeration = "super::datafusion_common::NullEquality", tag = "7")]
+    pub null_equality: i32,
     #[prost(message, optional, tag = "8")]
     pub filter: ::core::option::Option<JoinFilter>,
     #[prost(uint32, repeated, tag = "9")]
@@ -1611,8 +1611,8 @@ pub struct SymmetricHashJoinExecNode {
     pub join_type: i32,
     #[prost(enumeration = "StreamPartitionMode", tag = "6")]
     pub partition_mode: i32,
-    #[prost(bool, tag = "7")]
-    pub null_equals_null: bool,
+    #[prost(enumeration = "super::datafusion_common::NullEquality", tag = "7")]
+    pub null_equality: i32,
     #[prost(message, optional, tag = "8")]
     pub filter: ::core::option::Option<JoinFilter>,
     #[prost(message, repeated, tag = "9")]
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs
index 162ed7ae25..66ef0ebfe3 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -19,8 +19,8 @@ use std::sync::Arc;
 
 use datafusion::execution::registry::FunctionRegistry;
 use datafusion_common::{
-    exec_datafusion_err, internal_err, plan_datafusion_err, RecursionUnnestOption,
-    Result, ScalarValue, TableReference, UnnestOptions,
+    exec_datafusion_err, internal_err, plan_datafusion_err, NullEquality,
+    RecursionUnnestOption, Result, ScalarValue, TableReference, UnnestOptions,
 };
 use datafusion_expr::dml::InsertOp;
 use datafusion_expr::expr::{Alias, Placeholder, Sort};
@@ -219,6 +219,15 @@ impl From<protobuf::JoinConstraint> for JoinConstraint {
     }
 }
 
+impl From<protobuf::NullEquality> for NullEquality {
+    fn from(t: protobuf::NullEquality) -> Self {
+        match t {
+            protobuf::NullEquality::NullEqualsNothing => NullEquality::NullEqualsNothing,
+            protobuf::NullEquality::NullEqualsNull => NullEquality::NullEqualsNull,
+        }
+    }
+}
+
 impl From<protobuf::dml_node::Type> for WriteOp {
     fn from(t: protobuf::dml_node::Type) -> Self {
         match t {
diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs
index d934b24dc3..1acf1ee27b 100644
--- a/datafusion/proto/src/logical_plan/mod.rs
+++ b/datafusion/proto/src/logical_plan/mod.rs
@@ -1327,7 +1327,7 @@ impl AsLogicalPlan for LogicalPlanNode {
                 filter,
                 join_type,
                 join_constraint,
-                null_equals_null,
+                null_equality,
                 ..
             }) => {
                let left: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan(
@@ -1352,6 +1352,8 @@ impl AsLogicalPlan for LogicalPlanNode {
                let join_type: protobuf::JoinType = join_type.to_owned().into();
                 let join_constraint: protobuf::JoinConstraint =
                     join_constraint.to_owned().into();
+                let null_equality: protobuf::NullEquality =
+                    null_equality.to_owned().into();
                 let filter = filter
                     .as_ref()
                     .map(|e| serialize_expr(e, extension_codec))
@@ -1365,7 +1367,7 @@ impl AsLogicalPlan for LogicalPlanNode {
                             join_constraint: join_constraint.into(),
                             left_join_key,
                             right_join_key,
-                            null_equals_null: *null_equals_null,
+                            null_equality: null_equality.into(),
                             filter,
                         },
                     ))),
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index f7f303fefe..b14ad7aadf 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -21,7 +21,7 @@
 
 use std::collections::HashMap;
 
-use datafusion_common::{TableReference, UnnestOptions};
+use datafusion_common::{NullEquality, TableReference, UnnestOptions};
 use datafusion_expr::dml::InsertOp;
 use datafusion_expr::expr::{
    self, AggregateFunctionParams, Alias, Between, BinaryExpr, Cast, GroupingSet, InList,
@@ -699,6 +699,15 @@ impl From<JoinConstraint> for protobuf::JoinConstraint {
     }
 }
 
+impl From<NullEquality> for protobuf::NullEquality {
+    fn from(t: NullEquality) -> Self {
+        match t {
+            NullEquality::NullEqualsNothing => protobuf::NullEquality::NullEqualsNothing,
+            NullEquality::NullEqualsNull => protobuf::NullEquality::NullEqualsNull,
+        }
+    }
+}
+
 impl From<&WriteOp> for protobuf::dml_node::Type {
     fn from(t: &WriteOp) -> Self {
         match t {
diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs
index e2c391d044..3d541f54fe 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -1152,6 +1152,13 @@ impl protobuf::PhysicalPlanNode {
                     hashjoin.join_type
                 ))
             })?;
+        let null_equality = protobuf::NullEquality::try_from(hashjoin.null_equality)
+            .map_err(|_| {
+                proto_error(format!(
+                    "Received a HashJoinNode message with unknown NullEquality {}",
+                    hashjoin.null_equality
+                ))
+            })?;
         let filter = hashjoin
             .filter
             .as_ref()
@@ -1220,7 +1227,7 @@ impl protobuf::PhysicalPlanNode {
             &join_type.into(),
             projection,
             partition_mode,
-            hashjoin.null_equals_null,
+            null_equality.into(),
         )?))
     }
 
@@ -1263,6 +1270,13 @@ impl protobuf::PhysicalPlanNode {
                     sym_join.join_type
                 ))
             })?;
+        let null_equality = protobuf::NullEquality::try_from(sym_join.null_equality)
+            .map_err(|_| {
+                proto_error(format!(
+                    "Received a SymmetricHashJoin message with unknown NullEquality {}",
+                    sym_join.null_equality
+                ))
+            })?;
         let filter = sym_join
             .filter
             .as_ref()
@@ -1339,7 +1353,7 @@ impl protobuf::PhysicalPlanNode {
             on,
             filter,
             &join_type.into(),
-            sym_join.null_equals_null,
+            null_equality.into(),
             left_sort_exprs,
             right_sort_exprs,
             partition_mode,
@@ -1950,6 +1964,7 @@ impl protobuf::PhysicalPlanNode {
             })
             .collect::<Result<_>>()?;
         let join_type: protobuf::JoinType = exec.join_type().to_owned().into();
+        let null_equality: protobuf::NullEquality = exec.null_equality().into();
         let filter = exec
             .filter()
             .as_ref()
@@ -1990,7 +2005,7 @@ impl protobuf::PhysicalPlanNode {
                     on,
                     join_type: join_type.into(),
                     partition_mode: partition_mode.into(),
-                    null_equals_null: exec.null_equals_null(),
+                    null_equality: null_equality.into(),
                     filter,
                    projection: exec.projection.as_ref().map_or_else(Vec::new, |v| {
                         v.iter().map(|x| *x as u32).collect::<Vec<u32>>()
@@ -2025,6 +2040,7 @@ impl protobuf::PhysicalPlanNode {
             })
             .collect::<Result<_>>()?;
         let join_type: protobuf::JoinType = exec.join_type().to_owned().into();
+        let null_equality: protobuf::NullEquality = exec.null_equality().into();
         let filter = exec
             .filter()
             .as_ref()
@@ -2108,7 +2124,7 @@ impl protobuf::PhysicalPlanNode {
                     on,
                     join_type: join_type.into(),
                     partition_mode: partition_mode.into(),
-                    null_equals_null: exec.null_equals_null(),
+                    null_equality: null_equality.into(),
                     left_sort_exprs,
                     right_sort_exprs,
                     filter,
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index ab419f3713..43f9942a0a 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -98,7 +98,7 @@ use datafusion_common::file_options::json_writer::JsonWriterOptions;
 use datafusion_common::parsers::CompressionTypeVariant;
 use datafusion_common::stats::Precision;
 use datafusion_common::{
-    internal_err, not_impl_err, DataFusionError, Result, UnnestOptions,
+    internal_err, not_impl_err, DataFusionError, NullEquality, Result, UnnestOptions,
 };
 use datafusion_expr::{
    Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF,
@@ -268,7 +268,7 @@ fn roundtrip_hash_join() -> Result<()> {
                 join_type,
                 None,
                 *partition_mode,
-                false,
+                NullEquality::NullEqualsNothing,
             )?))?;
         }
     }
@@ -1494,7 +1494,7 @@ fn roundtrip_sym_hash_join() -> Result<()> {
                         on.clone(),
                         None,
                         join_type,
-                        false,
+                        NullEquality::NullEqualsNothing,
                         left_order.clone(),
                         right_order,
                         *partition_mode,
diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt
index ccecb94943..3be5c1b1c3 100644
--- a/datafusion/sqllogictest/test_files/joins.slt
+++ b/datafusion/sqllogictest/test_files/joins.slt
@@ -1373,7 +1373,7 @@ inner join join_t4 on join_t3.s3 = join_t4.s4
 {id: 2} {id: 2}
 
 # join with struct key and nulls
-# Note that intersect or except applies `null_equals_null` as true for Join.
+# Note that intersect or except applies `null_equality` as `NullEquality::NullEqualsNull` for Join.
 query ?
 SELECT * FROM join_t3
 EXCEPT
diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/join_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/join_rel.rs
index fab43a5ff4..0cf920dd62 100644
--- a/datafusion/substrait/src/logical_plan/consumer/rel/join_rel.rs
+++ b/datafusion/substrait/src/logical_plan/consumer/rel/join_rel.rs
@@ -17,7 +17,7 @@
 
 use crate::logical_plan::consumer::utils::requalify_sides_if_needed;
 use crate::logical_plan::consumer::SubstraitConsumer;
-use datafusion::common::{not_impl_err, plan_err, Column, JoinType};
+use datafusion::common::{not_impl_err, plan_err, Column, JoinType, NullEquality};
 use datafusion::logical_expr::utils::split_conjunction;
 use datafusion::logical_expr::{
     BinaryExpr, Expr, LogicalPlan, LogicalPlanBuilder, Operator,
@@ -59,19 +59,30 @@ pub async fn from_join_rel(
                 split_eq_and_noneq_join_predicate_with_nulls_equality(&on);
             let (left_cols, right_cols): (Vec<_>, Vec<_>) =
                 itertools::multiunzip(join_ons);
+            let null_equality = if nulls_equal_nulls {
+                NullEquality::NullEqualsNull
+            } else {
+                NullEquality::NullEqualsNothing
+            };
             left.join_detailed(
                 right.build()?,
                 join_type,
                 (left_cols, right_cols),
                 join_filter,
-                nulls_equal_nulls,
+                null_equality,
             )?
             .build()
         }
         None => {
             let on: Vec<String> = vec![];
-            left.join_detailed(right.build()?, join_type, (on.clone(), on), None, false)?
-                .build()
+            left.join_detailed(
+                right.build()?,
+                join_type,
+                (on.clone(), on),
+                None,
+                NullEquality::NullEqualsNothing,
+            )?
+            .build()
         }
     }
 }
diff --git a/datafusion/substrait/src/logical_plan/producer/rel/join.rs b/datafusion/substrait/src/logical_plan/producer/rel/join.rs
index 65c3e426d2..3dbac636fe 100644
--- a/datafusion/substrait/src/logical_plan/producer/rel/join.rs
+++ b/datafusion/substrait/src/logical_plan/producer/rel/join.rs
@@ -16,7 +16,9 @@
 // under the License.
 
 use crate::logical_plan::producer::{make_binary_op_scalar_func, SubstraitProducer};
-use datafusion::common::{not_impl_err, DFSchemaRef, JoinConstraint, JoinType};
+use datafusion::common::{
+    not_impl_err, DFSchemaRef, JoinConstraint, JoinType, NullEquality,
+};
 use datafusion::logical_expr::{Expr, Join, Operator};
 use std::sync::Arc;
 use substrait::proto::rel::RelType;
@@ -44,10 +46,9 @@ pub fn from_join(
 
     // map the left and right columns to binary expressions in the form `l = r`
     // build a single expression for the ON condition, such as `l.a = r.a AND l.b = r.b`
-    let eq_op = if join.null_equals_null {
-        Operator::IsNotDistinctFrom
-    } else {
-        Operator::Eq
+    let eq_op = match join.null_equality {
+        NullEquality::NullEqualsNothing => Operator::Eq,
+        NullEquality::NullEqualsNull => Operator::IsNotDistinctFrom,
     };
     let join_on = to_substrait_join_expr(producer, &join.on, eq_op, &in_join_schema)?;
 


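For callers, the change above is largely mechanical: the trailing `null_equals_null: bool` argument becomes a `datafusion_common::NullEquality` value, with `NullEqualsNothing` standing in for the old `false` and `NullEqualsNull` for the old `true`. In the generated protobuf code the first variant, `NullEqualsNothing`, maps to the default value 0, which is why the serializers now test `self.null_equality != 0` where they previously tested the boolean. Below is a minimal caller-side sketch, assuming the `LogicalPlanBuilder::join_detailed` usage exercised in the substrait consumer above; the "id" join key and the wrapper function are hypothetical and only illustrate the shape of the new API.

    use datafusion_common::{NullEquality, Result};
    use datafusion_expr::{JoinType, LogicalPlan, LogicalPlanBuilder};

    /// Join two plans on a shared "id" column (hypothetical example).
    fn join_on_id(left: LogicalPlanBuilder, right: LogicalPlan) -> Result<LogicalPlan> {
        left.join_detailed(
            right,
            JoinType::Inner,
            (vec!["id".to_string()], vec!["id".to_string()]),
            None,                             // no extra join filter
            NullEquality::NullEqualsNothing,  // was `null_equals_null = false`
        )?
        .build()
    }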
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org
