This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push: new 11fc52d318 Use dedicated NullEquality enum instead of null_equals_null boolean (#16419) 11fc52d318 is described below commit 11fc52d318ebdce2c683788563b39d76df04de70 Author: Tobias Schwarzinger <tobias.schwarzin...@tuwien.ac.at> AuthorDate: Wed Jun 18 00:52:27 2025 +0200 Use dedicated NullEquality enum instead of null_equals_null boolean (#16419) * Use dedicated NullEquality enum instead of null_equals_null boolean * Fix wrong operator mapping in hash_join * Add an example to the documentation --- datafusion/common/src/lib.rs | 2 + datafusion/common/src/null_equality.rs | 46 +++ datafusion/core/src/physical_planner.rs | 10 +- datafusion/core/tests/execution/infinite_cancel.rs | 10 +- datafusion/core/tests/fuzz_cases/join_fuzz.rs | 6 +- .../tests/physical_optimizer/join_selection.rs | 20 +- .../physical_optimizer/projection_pushdown.rs | 8 +- .../replace_with_order_preserving_variants.rs | 4 +- .../core/tests/physical_optimizer/test_utils.rs | 6 +- datafusion/expr/src/logical_plan/builder.rs | 43 ++- datafusion/expr/src/logical_plan/plan.rs | 50 +-- datafusion/expr/src/logical_plan/tree_node.rs | 8 +- datafusion/optimizer/src/eliminate_cross_join.rs | 42 +-- datafusion/optimizer/src/eliminate_outer_join.rs | 2 +- .../optimizer/src/extract_equijoin_predicate.rs | 6 +- datafusion/optimizer/src/filter_null_join_keys.rs | 5 +- .../physical-optimizer/src/enforce_distribution.rs | 16 +- .../physical-optimizer/src/join_selection.rs | 8 +- datafusion/physical-plan/src/joins/hash_join.rs | 388 +++++++++++++++------ .../physical-plan/src/joins/nested_loop_join.rs | 2 +- .../physical-plan/src/joins/sort_merge_join.rs | 116 +++--- .../physical-plan/src/joins/symmetric_hash_join.rs | 56 +-- datafusion/physical-plan/src/joins/test_utils.rs | 10 +- .../proto-common/proto/datafusion_common.proto | 5 + datafusion/proto-common/src/generated/pbjson.rs | 71 ++++ datafusion/proto-common/src/generated/prost.rs | 26 ++ datafusion/proto/proto/datafusion.proto | 6 +- .../proto/src/generated/datafusion_proto_common.rs | 26 ++ datafusion/proto/src/generated/pbjson.rs | 84 ++--- datafusion/proto/src/generated/prost.rs | 12 +- datafusion/proto/src/logical_plan/from_proto.rs | 13 +- datafusion/proto/src/logical_plan/mod.rs | 6 +- datafusion/proto/src/logical_plan/to_proto.rs | 11 +- datafusion/proto/src/physical_plan/mod.rs | 24 +- .../proto/tests/cases/roundtrip_physical_plan.rs | 6 +- datafusion/sqllogictest/test_files/joins.slt | 2 +- .../src/logical_plan/consumer/rel/join_rel.rs | 19 +- .../src/logical_plan/producer/rel/join.rs | 11 +- 38 files changed, 823 insertions(+), 363 deletions(-) diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 7b2c86d397..d89e08c7d4 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -46,6 +46,7 @@ pub mod file_options; pub mod format; pub mod hash_utils; pub mod instant; +mod null_equality; pub mod parsers; pub mod pruning; pub mod rounding; @@ -79,6 +80,7 @@ pub use functional_dependencies::{ }; use hashbrown::hash_map::DefaultHashBuilder; pub use join_type::{JoinConstraint, JoinSide, JoinType}; +pub use null_equality::NullEquality; pub use param_value::ParamValues; pub use scalar::{ScalarType, ScalarValue}; pub use schema_reference::SchemaReference; diff --git a/datafusion/common/src/null_equality.rs b/datafusion/common/src/null_equality.rs new file mode 100644 index 0000000000..847fb09757 --- /dev/null +++ b/datafusion/common/src/null_equality.rs @@ -0,0 
+1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Represents the behavior for null values when evaluating equality. Currently, its primary use +/// case is to define the behavior of joins for null values. +/// +/// # Examples +/// +/// The following table shows the expected equality behavior for `NullEquality`. +/// +/// | A | B | NullEqualsNothing | NullEqualsNull | +/// |------|------|-------------------|----------------| +/// | NULL | NULL | false | true | +/// | NULL | 'b' | false | false | +/// | 'a' | NULL | false | false | +/// | 'a' | 'b' | false | false | +/// +/// # Order +/// +/// The order on this type represents the "restrictiveness" of the behavior. The more restrictive +/// a behavior is, the fewer elements are considered to be equal to null. +/// [NullEquality::NullEqualsNothing] represents the most restrictive behavior. +/// +/// This mirrors the old order with `null_equals_null` booleans, as `false` indicated that +/// `null != null`. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)] +pub enum NullEquality { + /// Null is *not* equal to anything (`null != null`) + NullEqualsNothing, + /// Null is equal to null (`null == null`) + NullEqualsNull, +} diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index a8086b4645..14188f6bf0 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -901,12 +901,10 @@ impl DefaultPhysicalPlanner { on: keys, filter, join_type, - null_equals_null, + null_equality, schema: join_schema, .. }) => { - let null_equals_null = *null_equals_null; - let [physical_left, physical_right] = children.two()?; // If join has expression equijoin keys, add physical projection. @@ -1127,7 +1125,7 @@ impl DefaultPhysicalPlanner { join_filter, *join_type, vec![SortOptions::default(); join_on_len], - null_equals_null, + *null_equality, )?) } else if session_state.config().target_partitions() > 1 && session_state.config().repartition_joins() @@ -1141,7 +1139,7 @@ impl DefaultPhysicalPlanner { join_type, None, PartitionMode::Auto, - null_equals_null, + *null_equality, )?) } else { Arc::new(HashJoinExec::try_new( @@ -1152,7 +1150,7 @@ impl DefaultPhysicalPlanner { join_type, None, PartitionMode::CollectLeft, - null_equals_null, + *null_equality, )?) 
}; diff --git a/datafusion/core/tests/execution/infinite_cancel.rs b/datafusion/core/tests/execution/infinite_cancel.rs index 00c1f6b448..ea35dd367b 100644 --- a/datafusion/core/tests/execution/infinite_cancel.rs +++ b/datafusion/core/tests/execution/infinite_cancel.rs @@ -33,7 +33,7 @@ use datafusion::physical_plan::execution_plan::Boundedness; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionContext; use datafusion_common::config::ConfigOptions; -use datafusion_common::{JoinType, ScalarValue}; +use datafusion_common::{JoinType, NullEquality, ScalarValue}; use datafusion_expr_common::operator::Operator::Gt; use datafusion_physical_expr::expressions::{col, BinaryExpr, Column, Literal}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; @@ -467,7 +467,7 @@ async fn test_infinite_join_cancel( &JoinType::Inner, None, PartitionMode::CollectLeft, - true, + NullEquality::NullEqualsNull, )?); // 3) Wrap yields under each infinite leaf @@ -550,7 +550,7 @@ async fn test_infinite_join_agg_cancel( &JoinType::Inner, None, PartitionMode::CollectLeft, - true, + NullEquality::NullEqualsNull, )?); // 3) Project only one column (“value” from the left side) because we just want to sum that @@ -714,7 +714,7 @@ async fn test_infinite_hash_join_without_repartition_and_no_agg( /* output64 */ None, // Using CollectLeft is fine—just avoid RepartitionExec’s partitioned channels. PartitionMode::CollectLeft, - /* build_left */ true, + /* build_left */ NullEquality::NullEqualsNull, )?); // 3) Do not apply InsertYieldExec—since there is no aggregation, InsertYieldExec would @@ -796,7 +796,7 @@ async fn test_infinite_sort_merge_join_without_repartition_and_no_agg( /* filter */ None, JoinType::Inner, vec![SortOptions::new(true, false)], // ascending, nulls last - /* null_equal */ true, + /* null_equality */ NullEquality::NullEqualsNull, )?); // 3) Do not apply InsertYieldExec (no aggregation, no repartition → no built-in yields). 
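[Note on the migration pattern used throughout this patch: the old `null_equals_null` boolean maps onto the `NullEquality` enum introduced above as `true` -> `NullEqualsNull` and `false` -> `NullEqualsNothing`. The following is an illustrative sketch only, not code from the patch; the helper names are hypothetical.]

    use datafusion_common::NullEquality;

    // Hypothetical helpers restating the mapping applied mechanically at
    // every call site changed in this commit.
    fn null_equality_from_bool(null_equals_null: bool) -> NullEquality {
        if null_equals_null {
            NullEquality::NullEqualsNull
        } else {
            NullEquality::NullEqualsNothing
        }
    }

    fn null_equality_to_bool(null_equality: NullEquality) -> bool {
        matches!(null_equality, NullEquality::NullEqualsNull)
    }
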
diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 1a8064ac1e..7250a263d8 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -37,7 +37,7 @@ use datafusion::physical_plan::joins::{ HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec, }; use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_common::ScalarValue; +use datafusion_common::{NullEquality, ScalarValue}; use datafusion_physical_expr::expressions::Literal; use datafusion_physical_expr::PhysicalExprRef; @@ -504,7 +504,7 @@ impl JoinFuzzTestCase { self.join_filter(), self.join_type, vec![SortOptions::default(); self.on_columns().len()], - false, + NullEquality::NullEqualsNothing, ) .unwrap(), ) @@ -521,7 +521,7 @@ impl JoinFuzzTestCase { &self.join_type, None, PartitionMode::Partitioned, - false, + NullEquality::NullEqualsNothing, ) .unwrap(), ) diff --git a/datafusion/core/tests/physical_optimizer/join_selection.rs b/datafusion/core/tests/physical_optimizer/join_selection.rs index d8c0c142f7..3477ac7712 100644 --- a/datafusion/core/tests/physical_optimizer/join_selection.rs +++ b/datafusion/core/tests/physical_optimizer/join_selection.rs @@ -25,8 +25,8 @@ use std::{ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::config::ConfigOptions; -use datafusion_common::JoinSide; use datafusion_common::{stats::Precision, ColumnStatistics, JoinType, ScalarValue}; +use datafusion_common::{JoinSide, NullEquality}; use datafusion_common::{Result, Statistics}; use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext}; use datafusion_expr::Operator; @@ -222,7 +222,7 @@ async fn test_join_with_swap() { &JoinType::Left, None, PartitionMode::CollectLeft, - false, + NullEquality::NullEqualsNothing, ) .unwrap(), ); @@ -284,7 +284,7 @@ async fn test_left_join_no_swap() { &JoinType::Left, None, PartitionMode::CollectLeft, - false, + NullEquality::NullEqualsNothing, ) .unwrap(), ); @@ -333,7 +333,7 @@ async fn test_join_with_swap_semi() { &join_type, None, PartitionMode::Partitioned, - false, + NullEquality::NullEqualsNothing, ) .unwrap(); @@ -408,7 +408,7 @@ async fn test_nested_join_swap() { &JoinType::Inner, None, PartitionMode::CollectLeft, - false, + NullEquality::NullEqualsNothing, ) .unwrap(); let child_schema = child_join.schema(); @@ -425,7 +425,7 @@ async fn test_nested_join_swap() { &JoinType::Left, None, PartitionMode::CollectLeft, - false, + NullEquality::NullEqualsNothing, ) .unwrap(); @@ -464,7 +464,7 @@ async fn test_join_no_swap() { &JoinType::Inner, None, PartitionMode::CollectLeft, - false, + NullEquality::NullEqualsNothing, ) .unwrap(), ); @@ -690,7 +690,7 @@ async fn test_hash_join_swap_on_joins_with_projections( &join_type, Some(projection), PartitionMode::Partitioned, - false, + NullEquality::NullEqualsNothing, )?); let swapped = join @@ -851,7 +851,7 @@ fn check_join_partition_mode( &JoinType::Inner, None, PartitionMode::Auto, - false, + NullEquality::NullEqualsNothing, ) .unwrap(), ); @@ -1498,7 +1498,7 @@ async fn test_join_with_maybe_swap_unbounded_case(t: TestCase) -> Result<()> { &t.initial_join_type, None, t.initial_mode, - false, + NullEquality::NullEqualsNothing, )?) 
as _; let optimized_join_plan = hash_join_swap_subrule(join, &ConfigOptions::new())?; diff --git a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs index f2958deb57..1f8aad0f23 100644 --- a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs @@ -25,7 +25,7 @@ use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::physical_plan::CsvSource; use datafusion::datasource::source::DataSourceExec; use datafusion_common::config::ConfigOptions; -use datafusion_common::{JoinSide, JoinType, Result, ScalarValue}; +use datafusion_common::{JoinSide, JoinType, NullEquality, Result, ScalarValue}; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; @@ -883,7 +883,7 @@ fn test_join_after_projection() -> Result<()> { ])), )), &JoinType::Inner, - true, + NullEquality::NullEqualsNull, None, None, StreamJoinPartitionMode::SinglePartition, @@ -997,7 +997,7 @@ fn test_join_after_required_projection() -> Result<()> { ])), )), &JoinType::Inner, - true, + NullEquality::NullEqualsNull, None, None, StreamJoinPartitionMode::SinglePartition, @@ -1158,7 +1158,7 @@ fn test_hash_join_after_projection() -> Result<()> { &JoinType::Inner, None, PartitionMode::Auto, - true, + NullEquality::NullEqualsNull, )?); let projection = Arc::new(ProjectionExec::try_new( vec![ diff --git a/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs index f7c134cae3..c9baa9a932 100644 --- a/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs @@ -30,7 +30,7 @@ use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::tree_node::{TransformedResult, TreeNode}; -use datafusion_common::{assert_contains, Result}; +use datafusion_common::{assert_contains, NullEquality, Result}; use datafusion_common::config::ConfigOptions; use datafusion_datasource::source::DataSourceExec; use datafusion_execution::TaskContext; @@ -1171,7 +1171,7 @@ fn hash_join_exec( &JoinType::Inner, None, PartitionMode::Partitioned, - false, + NullEquality::NullEqualsNothing, ) .unwrap(), ) diff --git a/datafusion/core/tests/physical_optimizer/test_utils.rs b/datafusion/core/tests/physical_optimizer/test_utils.rs index ebb623dc61..c91a70989b 100644 --- a/datafusion/core/tests/physical_optimizer/test_utils.rs +++ b/datafusion/core/tests/physical_optimizer/test_utils.rs @@ -33,7 +33,7 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::stats::Precision; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::utils::expr::COUNT_STAR_EXPANSION; -use datafusion_common::{ColumnStatistics, JoinType, Result, Statistics}; +use datafusion_common::{ColumnStatistics, JoinType, NullEquality, Result, Statistics}; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; @@ -190,7 +190,7 @@ pub fn sort_merge_join_exec( None, *join_type, 
vec![SortOptions::default(); join_on.len()], - false, + NullEquality::NullEqualsNothing, ) .unwrap(), ) @@ -236,7 +236,7 @@ pub fn hash_join_exec( join_type, None, PartitionMode::Partitioned, - true, + NullEquality::NullEqualsNull, )?)) } diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index fbd1fdadc4..93dd6c2b89 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -56,7 +56,8 @@ use datafusion_common::file_options::file_type::FileType; use datafusion_common::{ exec_err, get_target_functional_dependencies, internal_err, not_impl_err, plan_datafusion_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef, - DataFusionError, Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions, + DataFusionError, NullEquality, Result, ScalarValue, TableReference, ToDFSchema, + UnnestOptions, }; use datafusion_expr_common::type_coercion::binary::type_union_resolution; @@ -903,7 +904,13 @@ impl LogicalPlanBuilder { join_keys: (Vec<impl Into<Column>>, Vec<impl Into<Column>>), filter: Option<Expr>, ) -> Result<Self> { - self.join_detailed(right, join_type, join_keys, filter, false) + self.join_detailed( + right, + join_type, + join_keys, + filter, + NullEquality::NullEqualsNothing, + ) } /// Apply a join using the specified expressions. @@ -959,7 +966,7 @@ impl LogicalPlanBuilder { join_type, (Vec::<Column>::new(), Vec::<Column>::new()), filter, - false, + NullEquality::NullEqualsNothing, ) } @@ -987,16 +994,14 @@ impl LogicalPlanBuilder { /// The behavior is the same as [`join`](Self::join) except that it allows /// specifying the null equality behavior. /// - /// If `null_equals_null=true`, rows where both join keys are `null` will be - /// emitted. Otherwise rows where either or both join keys are `null` will be - /// omitted. + /// The `null_equality` dictates how `null` values are joined. pub fn join_detailed( self, right: LogicalPlan, join_type: JoinType, join_keys: (Vec<impl Into<Column>>, Vec<impl Into<Column>>), filter: Option<Expr>, - null_equals_null: bool, + null_equality: NullEquality, ) -> Result<Self> { if join_keys.0.len() != join_keys.1.len() { return plan_err!("left_keys and right_keys were not the same length"); @@ -1113,7 +1118,7 @@ impl LogicalPlanBuilder { join_type, join_constraint: JoinConstraint::On, schema: DFSchemaRef::new(join_schema), - null_equals_null, + null_equality, }))) } @@ -1186,7 +1191,7 @@ impl LogicalPlanBuilder { filters, join_type, JoinConstraint::Using, - false, + NullEquality::NullEqualsNothing, )?; Ok(Self::new(LogicalPlan::Join(join))) @@ -1202,7 +1207,7 @@ impl LogicalPlanBuilder { None, JoinType::Inner, JoinConstraint::On, - false, + NullEquality::NullEqualsNothing, )?; Ok(Self::new(LogicalPlan::Join(join))) @@ -1340,12 +1345,24 @@ impl LogicalPlanBuilder { .unzip(); if is_all { LogicalPlanBuilder::from(left_plan) - .join_detailed(right_plan, join_type, join_keys, None, true)? + .join_detailed( + right_plan, + join_type, + join_keys, + None, + NullEquality::NullEqualsNull, + )? .build() } else { LogicalPlanBuilder::from(left_plan) .distinct()? - .join_detailed(right_plan, join_type, join_keys, None, true)? + .join_detailed( + right_plan, + join_type, + join_keys, + None, + NullEquality::NullEqualsNull, + )? 
.build() } } @@ -1423,7 +1440,7 @@ impl LogicalPlanBuilder { filter, join_type, JoinConstraint::On, - false, + NullEquality::NullEqualsNothing, )?; Ok(Self::new(LogicalPlan::Join(join))) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 4ac2d182aa..876c14f100 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -56,8 +56,8 @@ use datafusion_common::tree_node::{ use datafusion_common::{ aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependence, - FunctionalDependencies, ParamValues, Result, ScalarValue, Spans, TableReference, - UnnestOptions, + FunctionalDependencies, NullEquality, ParamValues, Result, ScalarValue, Spans, + TableReference, UnnestOptions, }; use indexmap::IndexSet; @@ -657,7 +657,7 @@ impl LogicalPlan { join_constraint, on, schema: _, - null_equals_null, + null_equality, }) => { let schema = build_join_schema(left.schema(), right.schema(), &join_type)?; @@ -678,7 +678,7 @@ impl LogicalPlan { on: new_on, filter, schema: DFSchemaRef::new(schema), - null_equals_null, + null_equality, })) } LogicalPlan::Subquery(_) => Ok(self), @@ -896,7 +896,7 @@ impl LogicalPlan { join_type, join_constraint, on, - null_equals_null, + null_equality, .. }) => { let (left, right) = self.only_two_inputs(inputs)?; @@ -935,7 +935,7 @@ impl LogicalPlan { on: new_on, filter: filter_expr, schema: DFSchemaRef::new(schema), - null_equals_null: *null_equals_null, + null_equality: *null_equality, })) } LogicalPlan::Subquery(Subquery { @@ -3708,8 +3708,8 @@ pub struct Join { pub join_constraint: JoinConstraint, /// The output schema, containing fields from the left and right inputs pub schema: DFSchemaRef, - /// If null_equals_null is true, null == null else null != null - pub null_equals_null: bool, + /// Defines the null equality for the join. + pub null_equality: NullEquality, } impl Join { @@ -3726,7 +3726,7 @@ impl Join { /// * `filter` - Optional filter expression (for non-equijoin conditions) /// * `join_type` - Type of join (Inner, Left, Right, etc.) 
/// * `join_constraint` - Join constraint (On, Using) - /// * `null_equals_null` - Whether NULL = NULL in join comparisons + /// * `null_equality` - How to handle nulls in join comparisons /// /// # Returns /// @@ -3738,7 +3738,7 @@ impl Join { filter: Option<Expr>, join_type: JoinType, join_constraint: JoinConstraint, - null_equals_null: bool, + null_equality: NullEquality, ) -> Result<Self> { let join_schema = build_join_schema(left.schema(), right.schema(), &join_type)?; @@ -3750,7 +3750,7 @@ impl Join { join_type, join_constraint, schema: Arc::new(join_schema), - null_equals_null, + null_equality, }) } @@ -3783,7 +3783,7 @@ impl Join { join_type: original_join.join_type, join_constraint: original_join.join_constraint, schema: Arc::new(join_schema), - null_equals_null: original_join.null_equals_null, + null_equality: original_join.null_equality, }) } } @@ -3805,8 +3805,8 @@ impl PartialOrd for Join { pub join_type: &'a JoinType, /// Join constraint pub join_constraint: &'a JoinConstraint, - /// If null_equals_null is true, null == null else null != null - pub null_equals_null: &'a bool, + /// The null handling behavior for equalities + pub null_equality: &'a NullEquality, } let comparable_self = ComparableJoin { left: &self.left, @@ -3815,7 +3815,7 @@ impl PartialOrd for Join { filter: &self.filter, join_type: &self.join_type, join_constraint: &self.join_constraint, - null_equals_null: &self.null_equals_null, + null_equality: &self.null_equality, }; let comparable_other = ComparableJoin { left: &other.left, @@ -3824,7 +3824,7 @@ impl PartialOrd for Join { filter: &other.filter, join_type: &other.join_type, join_constraint: &other.join_constraint, - null_equals_null: &other.null_equals_null, + null_equality: &other.null_equality, }; comparable_self.partial_cmp(&comparable_other) } @@ -4895,7 +4895,7 @@ mod tests { join_type: JoinType::Inner, join_constraint: JoinConstraint::On, schema: Arc::new(left_schema.join(&right_schema)?), - null_equals_null: false, + null_equality: NullEquality::NullEqualsNothing, })) } @@ -5006,7 +5006,7 @@ mod tests { Some(col("t1.b").gt(col("t2.b"))), join_type, JoinConstraint::On, - false, + NullEquality::NullEqualsNothing, )?; match join_type { @@ -5116,7 +5116,7 @@ mod tests { assert_eq!(join.filter, Some(col("t1.b").gt(col("t2.b")))); assert_eq!(join.join_type, join_type); assert_eq!(join.join_constraint, JoinConstraint::On); - assert!(!join.null_equals_null); + assert_eq!(join.null_equality, NullEquality::NullEqualsNothing); } Ok(()) @@ -5151,7 +5151,7 @@ mod tests { None, JoinType::Inner, JoinConstraint::Using, - false, + NullEquality::NullEqualsNothing, )?; let fields = join.schema.fields(); @@ -5202,7 +5202,7 @@ mod tests { Some(col("t1.value").lt(col("t2.value"))), // Non-equi filter condition JoinType::Inner, JoinConstraint::On, - false, + NullEquality::NullEqualsNothing, )?; let fields = join.schema.fields(); @@ -5251,10 +5251,10 @@ mod tests { None, JoinType::Inner, JoinConstraint::On, - true, + NullEquality::NullEqualsNull, )?; - assert!(join.null_equals_null); + assert_eq!(join.null_equality, NullEquality::NullEqualsNull); } Ok(()) @@ -5293,7 +5293,7 @@ mod tests { Some(col("t1.value").gt(lit(5.0))), join_type, JoinConstraint::On, - false, + NullEquality::NullEqualsNothing, )?; let fields = join.schema.fields(); @@ -5332,7 +5332,7 @@ mod tests { None, JoinType::Inner, JoinConstraint::Using, - false, + NullEquality::NullEqualsNothing, )?; assert_eq!( diff --git a/datafusion/expr/src/logical_plan/tree_node.rs 
b/datafusion/expr/src/logical_plan/tree_node.rs index 7f5b7e07ed..527248ad39 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -132,7 +132,7 @@ impl TreeNode for LogicalPlan { join_type, join_constraint, schema, - null_equals_null, + null_equality, }) => (left, right).map_elements(f)?.update_data(|(left, right)| { LogicalPlan::Join(Join { left, @@ -142,7 +142,7 @@ impl TreeNode for LogicalPlan { join_type, join_constraint, schema, - null_equals_null, + null_equality, }) }), LogicalPlan::Limit(Limit { skip, fetch, input }) => input @@ -561,7 +561,7 @@ impl LogicalPlan { join_type, join_constraint, schema, - null_equals_null, + null_equality, }) => (on, filter).map_elements(f)?.update_data(|(on, filter)| { LogicalPlan::Join(Join { left, @@ -571,7 +571,7 @@ impl LogicalPlan { join_type, join_constraint, schema, - null_equals_null, + null_equality, }) }), LogicalPlan::Sort(Sort { expr, input, fetch }) => expr diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index deefaef2c0..ae1d7df46d 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use crate::join_key_set::JoinKeySet; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::Result; +use datafusion_common::{NullEquality, Result}; use datafusion_expr::expr::{BinaryExpr, Expr}; use datafusion_expr::logical_plan::{ Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection, @@ -89,7 +89,7 @@ impl OptimizerRule for EliminateCrossJoin { let mut possible_join_keys = JoinKeySet::new(); let mut all_inputs: Vec<LogicalPlan> = vec![]; let mut all_filters: Vec<Expr> = vec![]; - let mut null_equals_null = false; + let mut null_equality = NullEquality::NullEqualsNothing; let parent_predicate = if let LogicalPlan::Filter(filter) = plan { // if input isn't a join that can potentially be rewritten @@ -115,9 +115,9 @@ impl OptimizerRule for EliminateCrossJoin { input, predicate, .. } = filter; - // Extract null_equals_null setting from the input join + // Extract null_equality setting from the input join if let LogicalPlan::Join(join) = input.as_ref() { - null_equals_null = join.null_equals_null; + null_equality = join.null_equality; } flatten_join_inputs( @@ -133,7 +133,7 @@ impl OptimizerRule for EliminateCrossJoin { match plan { LogicalPlan::Join(Join { join_type: JoinType::Inner, - null_equals_null: original_null_equals_null, + null_equality: original_null_equality, .. 
}) => { if !can_flatten_join_inputs(&plan) { @@ -145,7 +145,7 @@ impl OptimizerRule for EliminateCrossJoin { &mut all_inputs, &mut all_filters, )?; - null_equals_null = original_null_equals_null; + null_equality = original_null_equality; None } _ => { @@ -164,7 +164,7 @@ impl OptimizerRule for EliminateCrossJoin { &mut all_inputs, &possible_join_keys, &mut all_join_keys, - null_equals_null, + null_equality, )?; } @@ -302,7 +302,7 @@ fn find_inner_join( rights: &mut Vec<LogicalPlan>, possible_join_keys: &JoinKeySet, all_join_keys: &mut JoinKeySet, - null_equals_null: bool, + null_equality: NullEquality, ) -> Result<LogicalPlan> { for (i, right_input) in rights.iter().enumerate() { let mut join_keys = vec![]; @@ -341,7 +341,7 @@ fn find_inner_join( on: join_keys, filter: None, schema: join_schema, - null_equals_null, + null_equality, })); } } @@ -363,7 +363,7 @@ fn find_inner_join( filter: None, join_type: JoinType::Inner, join_constraint: JoinConstraint::On, - null_equals_null, + null_equality, })) } @@ -1348,11 +1348,11 @@ mod tests { } #[test] - fn preserve_null_equals_null_setting() -> Result<()> { + fn preserve_null_equality_setting() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; let t2 = test_table_scan_with_name("t2")?; - // Create an inner join with null_equals_null: true + // Create an inner join with NullEquality::NullEqualsNull let join_schema = Arc::new(build_join_schema( t1.schema(), t2.schema(), @@ -1367,7 +1367,7 @@ mod tests { on: vec![], filter: None, schema: join_schema, - null_equals_null: true, // Set to true to test preservation + null_equality: NullEquality::NullEqualsNull, // Test preservation }); // Apply filter that can create join conditions @@ -1382,31 +1382,31 @@ mod tests { let rule = EliminateCrossJoin::new(); let optimized_plan = rule.rewrite(plan, &OptimizerContext::new())?.data; - // Verify that null_equals_null is preserved in the optimized plan - fn check_null_equals_null_preserved(plan: &LogicalPlan) -> bool { + // Verify that null_equality is preserved in the optimized plan + fn check_null_equality_preserved(plan: &LogicalPlan) -> bool { match plan { LogicalPlan::Join(join) => { - // All joins in the optimized plan should preserve null_equals_null: true - if !join.null_equals_null { + // All joins in the optimized plan should preserve null equality + if join.null_equality == NullEquality::NullEqualsNothing { return false; } // Recursively check child plans plan.inputs() .iter() - .all(|input| check_null_equals_null_preserved(input)) + .all(|input| check_null_equality_preserved(input)) } _ => { // Recursively check child plans for non-join nodes plan.inputs() .iter() - .all(|input| check_null_equals_null_preserved(input)) + .all(|input| check_null_equality_preserved(input)) } } } assert!( - check_null_equals_null_preserved(&optimized_plan), - "null_equals_null setting should be preserved after optimization" + check_null_equality_preserved(&optimized_plan), + "null_equality setting should be preserved after optimization" ); Ok(()) diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index 621086e4a2..45877642f2 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -118,7 +118,7 @@ impl OptimizerRule for EliminateOuterJoin { on: join.on.clone(), filter: join.filter.clone(), schema: Arc::clone(&join.schema), - null_equals_null: join.null_equals_null, + null_equality: join.null_equality, })); 
Filter::try_new(filter.predicate, new_join) .map(|f| Transformed::yes(LogicalPlan::Filter(f))) diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs index a07b50ade5..55cf33ef43 100644 --- a/datafusion/optimizer/src/extract_equijoin_predicate.rs +++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs @@ -75,7 +75,7 @@ impl OptimizerRule for ExtractEquijoinPredicate { join_type, join_constraint, schema, - null_equals_null, + null_equality, }) => { let left_schema = left.schema(); let right_schema = right.schema(); @@ -92,7 +92,7 @@ impl OptimizerRule for ExtractEquijoinPredicate { join_type, join_constraint, schema, - null_equals_null, + null_equality, }))) } else { Ok(Transformed::no(LogicalPlan::Join(Join { @@ -103,7 +103,7 @@ impl OptimizerRule for ExtractEquijoinPredicate { join_type, join_constraint, schema, - null_equals_null, + null_equality, }))) } } diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs b/datafusion/optimizer/src/filter_null_join_keys.rs index 14a424b326..8ad7fa53c0 100644 --- a/datafusion/optimizer/src/filter_null_join_keys.rs +++ b/datafusion/optimizer/src/filter_null_join_keys.rs @@ -21,7 +21,7 @@ use crate::optimizer::ApplyOrder; use crate::push_down_filter::on_lr_is_preserved; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; -use datafusion_common::Result; +use datafusion_common::{NullEquality, Result}; use datafusion_expr::utils::conjunction; use datafusion_expr::{logical_plan::Filter, Expr, ExprSchemable, LogicalPlan}; use std::sync::Arc; @@ -51,7 +51,8 @@ impl OptimizerRule for FilterNullJoinKeys { } match plan { LogicalPlan::Join(mut join) - if !join.on.is_empty() && !join.null_equals_null => + if !join.on.is_empty() + && join.null_equality == NullEquality::NullEqualsNothing => { let (left_preserved, right_preserved) = on_lr_is_preserved(join.join_type); diff --git a/datafusion/physical-optimizer/src/enforce_distribution.rs b/datafusion/physical-optimizer/src/enforce_distribution.rs index 566cf2f3a2..39eb557ea6 100644 --- a/datafusion/physical-optimizer/src/enforce_distribution.rs +++ b/datafusion/physical-optimizer/src/enforce_distribution.rs @@ -295,7 +295,7 @@ pub fn adjust_input_keys_ordering( join_type, projection, mode, - null_equals_null, + null_equality, .. }) = plan.as_any().downcast_ref::<HashJoinExec>() { @@ -314,7 +314,7 @@ pub fn adjust_input_keys_ordering( // TODO: although projection is not used in the join here, because projection pushdown is after enforce_distribution. Maybe we need to handle it later. Same as filter. projection.clone(), PartitionMode::Partitioned, - *null_equals_null, + *null_equality, ) .map(|e| Arc::new(e) as _) }; @@ -364,7 +364,7 @@ pub fn adjust_input_keys_ordering( filter, join_type, sort_options, - null_equals_null, + null_equality, .. }) = plan.as_any().downcast_ref::<SortMergeJoinExec>() { @@ -379,7 +379,7 @@ pub fn adjust_input_keys_ordering( filter.clone(), *join_type, new_conditions.1, - *null_equals_null, + *null_equality, ) .map(|e| Arc::new(e) as _) }; @@ -616,7 +616,7 @@ pub fn reorder_join_keys_to_inputs( join_type, projection, mode, - null_equals_null, + null_equality, .. 
}) = plan_any.downcast_ref::<HashJoinExec>() { @@ -642,7 +642,7 @@ pub fn reorder_join_keys_to_inputs( join_type, projection.clone(), PartitionMode::Partitioned, - *null_equals_null, + *null_equality, )?)); } } @@ -653,7 +653,7 @@ pub fn reorder_join_keys_to_inputs( filter, join_type, sort_options, - null_equals_null, + null_equality, .. }) = plan_any.downcast_ref::<SortMergeJoinExec>() { @@ -681,7 +681,7 @@ pub fn reorder_join_keys_to_inputs( filter.clone(), *join_type, new_sort_options, - *null_equals_null, + *null_equality, ) .map(|smj| Arc::new(smj) as _); } diff --git a/datafusion/physical-optimizer/src/join_selection.rs b/datafusion/physical-optimizer/src/join_selection.rs index 27eed70241..dc22033214 100644 --- a/datafusion/physical-optimizer/src/join_selection.rs +++ b/datafusion/physical-optimizer/src/join_selection.rs @@ -245,7 +245,7 @@ pub(crate) fn try_collect_left( hash_join.join_type(), hash_join.projection.clone(), PartitionMode::CollectLeft, - hash_join.null_equals_null(), + hash_join.null_equality(), )?))) } } @@ -257,7 +257,7 @@ pub(crate) fn try_collect_left( hash_join.join_type(), hash_join.projection.clone(), PartitionMode::CollectLeft, - hash_join.null_equals_null(), + hash_join.null_equality(), )?))), (false, true) => { if hash_join.join_type().supports_swap() { @@ -292,7 +292,7 @@ pub(crate) fn partitioned_hash_join( hash_join.join_type(), hash_join.projection.clone(), PartitionMode::Partitioned, - hash_join.null_equals_null(), + hash_join.null_equality(), )?)) } } @@ -474,7 +474,7 @@ fn hash_join_convert_symmetric_subrule( hash_join.on().to_vec(), hash_join.filter().cloned(), hash_join.join_type(), - hash_join.null_equals_null(), + hash_join.null_equality(), left_order, right_order, mode, diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 5034a199e2..8c4241be72 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -70,7 +70,7 @@ use arrow::util::bit_util; use datafusion_common::utils::memory::estimate_memory_size; use datafusion_common::{ internal_datafusion_err, internal_err, plan_err, project_schema, DataFusionError, - JoinSide, JoinType, Result, + JoinSide, JoinType, NullEquality, Result, }; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; @@ -353,11 +353,8 @@ pub struct HashJoinExec { pub projection: Option<Vec<usize>>, /// Information of index and left / right placement of columns column_indices: Vec<ColumnIndex>, - /// Null matching behavior: If `null_equals_null` is true, rows that have - /// `null`s in both left and right equijoin columns will be matched. - /// Otherwise, rows that have `null`s in the join columns will not be - /// matched and thus will not appear in the output. - pub null_equals_null: bool, + /// The equality null-handling behavior of the join algorithm. + pub null_equality: NullEquality, /// Cache holding plan properties like equivalences, output partitioning etc. 
cache: PlanProperties, } @@ -376,7 +373,7 @@ impl HashJoinExec { join_type: &JoinType, projection: Option<Vec<usize>>, partition_mode: PartitionMode, - null_equals_null: bool, + null_equality: NullEquality, ) -> Result<Self> { let left_schema = left.schema(); let right_schema = right.schema(); @@ -419,7 +416,7 @@ impl HashJoinExec { metrics: ExecutionPlanMetricsSet::new(), projection, column_indices, - null_equals_null, + null_equality, cache, }) } @@ -460,9 +457,9 @@ impl HashJoinExec { &self.mode } - /// Get null_equals_null - pub fn null_equals_null(&self) -> bool { - self.null_equals_null + /// Get null_equality + pub fn null_equality(&self) -> NullEquality { + self.null_equality } /// Calculate order preservation flags for this hash join. @@ -510,7 +507,7 @@ impl HashJoinExec { &self.join_type, projection, self.mode, - self.null_equals_null, + self.null_equality, ) } @@ -619,7 +616,7 @@ impl HashJoinExec { self.join_type(), ), partition_mode, - self.null_equals_null(), + self.null_equality(), )?; // In case of anti / semi joins or if there is embedded projection in HashJoinExec, output column order is preserved, no need to add projection again if matches!( @@ -767,7 +764,7 @@ impl ExecutionPlan for HashJoinExec { &self.join_type, self.projection.clone(), self.mode, - self.null_equals_null, + self.null_equality, )?)) } @@ -871,7 +868,7 @@ impl ExecutionPlan for HashJoinExec { column_indices: column_indices_after_projection, random_state: self.random_state.clone(), join_metrics, - null_equals_null: self.null_equals_null, + null_equality: self.null_equality, state: HashJoinStreamState::WaitBuildSide, build_side: BuildSide::Initial(BuildSideInitialState { left_fut }), batch_size, @@ -940,7 +937,7 @@ impl ExecutionPlan for HashJoinExec { // Returned early if projection is not None None, *self.partition_mode(), - self.null_equals_null, + self.null_equality, )?))) } else { try_embed_projection(projection, self) @@ -1231,8 +1228,8 @@ struct HashJoinStream { join_metrics: BuildProbeJoinMetrics, /// Information of index and left / right placement of columns column_indices: Vec<ColumnIndex>, - /// If null_equals_null is true, null == null else null != null - null_equals_null: bool, + /// Defines the null equality for the join. 
+ null_equality: NullEquality, /// State of the stream state: HashJoinStreamState, /// Build side @@ -1304,7 +1301,7 @@ fn lookup_join_hashmap( build_hashmap: &JoinHashMap, build_side_values: &[ArrayRef], probe_side_values: &[ArrayRef], - null_equals_null: bool, + null_equality: NullEquality, hashes_buffer: &[u64], limit: usize, offset: JoinHashMapOffset, @@ -1320,7 +1317,7 @@ fn lookup_join_hashmap( &probe_indices, build_side_values, probe_side_values, - null_equals_null, + null_equality, )?; Ok((build_indices, probe_indices, next_offset)) @@ -1330,22 +1327,21 @@ fn lookup_join_hashmap( fn eq_dyn_null( left: &dyn Array, right: &dyn Array, - null_equals_null: bool, + null_equality: NullEquality, ) -> Result<BooleanArray, ArrowError> { // Nested datatypes cannot use the underlying not_distinct/eq function and must use a special // implementation // <https://github.com/apache/datafusion/issues/10749> if left.data_type().is_nested() { - let op = if null_equals_null { - Operator::IsNotDistinctFrom - } else { - Operator::Eq + let op = match null_equality { + NullEquality::NullEqualsNothing => Operator::Eq, + NullEquality::NullEqualsNull => Operator::IsNotDistinctFrom, }; return Ok(compare_op_for_nested(op, &left, &right)?); } - match (left.data_type(), right.data_type()) { - _ if null_equals_null => not_distinct(&left, &right), - _ => eq(&left, &right), + match null_equality { + NullEquality::NullEqualsNothing => eq(&left, &right), + NullEquality::NullEqualsNull => not_distinct(&left, &right), } } @@ -1354,7 +1350,7 @@ pub fn equal_rows_arr( indices_right: &UInt32Array, left_arrays: &[ArrayRef], right_arrays: &[ArrayRef], - null_equals_null: bool, + null_equality: NullEquality, ) -> Result<(UInt64Array, UInt32Array)> { let mut iter = left_arrays.iter().zip(right_arrays.iter()); @@ -1367,7 +1363,7 @@ pub fn equal_rows_arr( let arr_left = take(first_left.as_ref(), indices_left, None)?; let arr_right = take(first_right.as_ref(), indices_right, None)?; - let mut equal: BooleanArray = eq_dyn_null(&arr_left, &arr_right, null_equals_null)?; + let mut equal: BooleanArray = eq_dyn_null(&arr_left, &arr_right, null_equality)?; // Use map and try_fold to iterate over the remaining pairs of arrays. // In each iteration, take is used on the pair of arrays and their equality is determined. 
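[Row-level illustration of what the kernel choice in `eq_dyn_null` above amounts to for join matching, mirroring the table documented in null_equality.rs. This is an informal sketch, not code from the patch; `null_aware_eq` is a hypothetical helper.]

    use datafusion_common::NullEquality;

    // Hypothetical helper: whether a pair of nullable key values is treated
    // as a join match, per the NullEquality documentation table.
    fn null_aware_eq<T: PartialEq>(a: Option<T>, b: Option<T>, ne: NullEquality) -> bool {
        match (a, b) {
            (None, None) => ne == NullEquality::NullEqualsNull,
            (None, Some(_)) | (Some(_), None) => false,
            (Some(x), Some(y)) => x == y,
        }
    }

    // null_aware_eq(None::<i32>, None, NullEquality::NullEqualsNothing) == false
    // null_aware_eq(None::<i32>, None, NullEquality::NullEqualsNull)    == true
    // null_aware_eq(Some(1), Some(1), NullEquality::NullEqualsNothing)  == true
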
@@ -1376,7 +1372,7 @@ pub fn equal_rows_arr( .map(|(left, right)| { let arr_left = take(left.as_ref(), indices_left, None)?; let arr_right = take(right.as_ref(), indices_right, None)?; - eq_dyn_null(arr_left.as_ref(), arr_right.as_ref(), null_equals_null) + eq_dyn_null(arr_left.as_ref(), arr_right.as_ref(), null_equality) }) .try_fold(equal, |acc, equal2| and(&acc, &equal2?))?; @@ -1496,7 +1492,7 @@ impl HashJoinStream { build_side.left_data.hash_map(), build_side.left_data.values(), &state.values, - self.null_equals_null, + self.null_equality, &self.hashes_buffer, self.batch_size, state.offset, @@ -1726,7 +1722,7 @@ mod tests { right: Arc<dyn ExecutionPlan>, on: JoinOn, join_type: &JoinType, - null_equals_null: bool, + null_equality: NullEquality, ) -> Result<HashJoinExec> { HashJoinExec::try_new( left, @@ -1736,7 +1732,7 @@ mod tests { join_type, None, PartitionMode::CollectLeft, - null_equals_null, + null_equality, ) } @@ -1746,7 +1742,7 @@ mod tests { on: JoinOn, filter: JoinFilter, join_type: &JoinType, - null_equals_null: bool, + null_equality: NullEquality, ) -> Result<HashJoinExec> { HashJoinExec::try_new( left, @@ -1756,7 +1752,7 @@ mod tests { join_type, None, PartitionMode::CollectLeft, - null_equals_null, + null_equality, ) } @@ -1765,10 +1761,10 @@ mod tests { right: Arc<dyn ExecutionPlan>, on: JoinOn, join_type: &JoinType, - null_equals_null: bool, + null_equality: NullEquality, context: Arc<TaskContext>, ) -> Result<(Vec<String>, Vec<RecordBatch>)> { - let join = join(left, right, on, join_type, null_equals_null)?; + let join = join(left, right, on, join_type, null_equality)?; let columns_header = columns(&join.schema()); let stream = join.execute(0, context)?; @@ -1782,7 +1778,7 @@ mod tests { right: Arc<dyn ExecutionPlan>, on: JoinOn, join_type: &JoinType, - null_equals_null: bool, + null_equality: NullEquality, context: Arc<TaskContext>, ) -> Result<(Vec<String>, Vec<RecordBatch>)> { join_collect_with_partition_mode( @@ -1791,7 +1787,7 @@ mod tests { on, join_type, PartitionMode::Partitioned, - null_equals_null, + null_equality, context, ) .await @@ -1803,7 +1799,7 @@ mod tests { on: JoinOn, join_type: &JoinType, partition_mode: PartitionMode, - null_equals_null: bool, + null_equality: NullEquality, context: Arc<TaskContext>, ) -> Result<(Vec<String>, Vec<RecordBatch>)> { let partition_count = 4; @@ -1853,7 +1849,7 @@ mod tests { join_type, None, partition_mode, - null_equals_null, + null_equality, )?; let columns = columns(&join.schema()); @@ -1898,7 +1894,7 @@ mod tests { Arc::clone(&right), on.clone(), &JoinType::Inner, - false, + NullEquality::NullEqualsNothing, task_ctx, ) .await?; @@ -1945,7 +1941,7 @@ mod tests { Arc::clone(&right), on.clone(), &JoinType::Inner, - false, + NullEquality::NullEqualsNothing, task_ctx, ) .await?; @@ -1985,8 +1981,15 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let (columns, batches) = - join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + let (columns, batches) = join_collect( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); @@ -2024,8 +2027,15 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) 
as _, )]; - let (columns, batches) = - join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + let (columns, batches) = join_collect( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); @@ -2071,8 +2081,15 @@ mod tests { ), ]; - let (columns, batches) = - join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + let (columns, batches) = join_collect( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); @@ -2142,8 +2159,15 @@ mod tests { ), ]; - let (columns, batches) = - join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + let (columns, batches) = join_collect( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); @@ -2208,8 +2232,15 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let (columns, batches) = - join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + let (columns, batches) = join_collect( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); @@ -2258,7 +2289,13 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let join = join(left, right, on, &JoinType::Inner, false)?; + let join = join( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -2351,7 +2388,14 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap()) as _, )]; - let join = join(left, right, on, &JoinType::Left, false).unwrap(); + let join = join( + left, + right, + on, + &JoinType::Left, + NullEquality::NullEqualsNothing, + ) + .unwrap(); let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -2394,7 +2438,14 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, )]; - let join = join(left, right, on, &JoinType::Full, false).unwrap(); + let join = join( + left, + right, + on, + &JoinType::Full, + NullEquality::NullEqualsNothing, + ) + .unwrap(); let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); @@ -2435,7 +2486,14 @@ mod tests { )]; let schema = right.schema(); let right = TestMemoryExec::try_new_exec(&[vec![right]], schema, None).unwrap(); - let join = join(left, right, on, &JoinType::Left, false).unwrap(); + let join = join( + left, + right, + on, + &JoinType::Left, + NullEquality::NullEqualsNothing, + ) + .unwrap(); let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -2472,7 +2530,14 @@ mod tests { )]; let schema = right.schema(); let right = TestMemoryExec::try_new_exec(&[vec![right]], schema, None).unwrap(); - let join = join(left, right, on, &JoinType::Full, false).unwrap(); + let join = join( + left, + right, + on, + &JoinType::Full, + NullEquality::NullEqualsNothing, + ) + .unwrap(); let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); @@ -2517,7 +2582,7 @@ mod tests { Arc::clone(&right), on.clone(), &JoinType::Left, - 
false, + NullEquality::NullEqualsNothing, task_ctx, ) .await?; @@ -2562,7 +2627,7 @@ mod tests { Arc::clone(&right), on.clone(), &JoinType::Left, - false, + NullEquality::NullEqualsNothing, task_ctx, ) .await?; @@ -2615,7 +2680,13 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let join = join(left, right, on, &JoinType::LeftSemi, false)?; + let join = join( + left, + right, + on, + &JoinType::LeftSemi, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1"]); @@ -2677,7 +2748,7 @@ mod tests { on.clone(), filter, &JoinType::LeftSemi, - false, + NullEquality::NullEqualsNothing, )?; let columns_header = columns(&join.schema()); @@ -2710,7 +2781,14 @@ mod tests { Arc::new(intermediate_schema), ); - let join = join_with_filter(left, right, on, filter, &JoinType::LeftSemi, false)?; + let join = join_with_filter( + left, + right, + on, + filter, + &JoinType::LeftSemi, + NullEquality::NullEqualsNothing, + )?; let columns_header = columns(&join.schema()); assert_eq!(columns_header, vec!["a1", "b1", "c1"]); @@ -2744,7 +2822,13 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let join = join(left, right, on, &JoinType::RightSemi, false)?; + let join = join( + left, + right, + on, + &JoinType::RightSemi, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a2", "b2", "c2"]); @@ -2806,7 +2890,7 @@ mod tests { on.clone(), filter, &JoinType::RightSemi, - false, + NullEquality::NullEqualsNothing, )?; let columns = columns(&join.schema()); @@ -2841,8 +2925,14 @@ mod tests { Arc::new(intermediate_schema.clone()), ); - let join = - join_with_filter(left, right, on, filter, &JoinType::RightSemi, false)?; + let join = join_with_filter( + left, + right, + on, + filter, + &JoinType::RightSemi, + NullEquality::NullEqualsNothing, + )?; let stream = join.execute(0, task_ctx)?; let batches = common::collect(stream).await?; @@ -2873,7 +2963,13 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let join = join(left, right, on, &JoinType::LeftAnti, false)?; + let join = join( + left, + right, + on, + &JoinType::LeftAnti, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1"]); @@ -2932,7 +3028,7 @@ mod tests { on.clone(), filter, &JoinType::LeftAnti, - false, + NullEquality::NullEqualsNothing, )?; let columns_header = columns(&join.schema()); @@ -2969,7 +3065,14 @@ mod tests { Arc::new(intermediate_schema), ); - let join = join_with_filter(left, right, on, filter, &JoinType::LeftAnti, false)?; + let join = join_with_filter( + left, + right, + on, + filter, + &JoinType::LeftAnti, + NullEquality::NullEqualsNothing, + )?; let columns_header = columns(&join.schema()); assert_eq!(columns_header, vec!["a1", "b1", "c1"]); @@ -3006,7 +3109,13 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) 
as _, )]; - let join = join(left, right, on, &JoinType::RightAnti, false)?; + let join = join( + left, + right, + on, + &JoinType::RightAnti, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a2", "b2", "c2"]); @@ -3066,7 +3175,7 @@ mod tests { on.clone(), filter, &JoinType::RightAnti, - false, + NullEquality::NullEqualsNothing, )?; let columns_header = columns(&join.schema()); @@ -3107,8 +3216,14 @@ mod tests { Arc::new(intermediate_schema), ); - let join = - join_with_filter(left, right, on, filter, &JoinType::RightAnti, false)?; + let join = join_with_filter( + left, + right, + on, + filter, + &JoinType::RightAnti, + NullEquality::NullEqualsNothing, + )?; let columns_header = columns(&join.schema()); assert_eq!(columns_header, vec!["a2", "b2", "c2"]); @@ -3152,8 +3267,15 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (columns, batches) = - join_collect(left, right, on, &JoinType::Right, false, task_ctx).await?; + let (columns, batches) = join_collect( + left, + right, + on, + &JoinType::Right, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -3191,9 +3313,15 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (columns, batches) = - partitioned_join_collect(left, right, on, &JoinType::Right, false, task_ctx) - .await?; + let (columns, batches) = partitioned_join_collect( + left, + right, + on, + &JoinType::Right, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -3231,7 +3359,13 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, )]; - let join = join(left, right, on, &JoinType::Full, false)?; + let join = join( + left, + right, + on, + &JoinType::Full, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); @@ -3279,7 +3413,7 @@ mod tests { Arc::clone(&right), on.clone(), &JoinType::LeftMark, - false, + NullEquality::NullEqualsNothing, task_ctx, ) .await?; @@ -3324,7 +3458,7 @@ mod tests { Arc::clone(&right), on.clone(), &JoinType::LeftMark, - false, + NullEquality::NullEqualsNothing, task_ctx, ) .await?; @@ -3488,7 +3622,7 @@ mod tests { &join_hash_map, &[left_keys_values], &[right_keys_values], - false, + NullEquality::NullEqualsNothing, &hashes_buffer, 8192, (0, None), @@ -3524,7 +3658,13 @@ mod tests { Arc::new(Column::new_with_schema("b", &right.schema()).unwrap()) as _, )]; - let join = join(left, right, on, &JoinType::Inner, false)?; + let join = join( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]); @@ -3594,7 +3734,14 @@ mod tests { )]; let filter = prepare_join_filter(); - let join = join_with_filter(left, right, on, filter, &JoinType::Inner, false)?; + let join = join_with_filter( + left, + right, + on, + filter, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]); @@ -3636,7 +3783,14 @@ mod tests { )]; let filter = prepare_join_filter(); - let join = join_with_filter(left, right, on, filter, &JoinType::Left, false)?; + let join = join_with_filter( + left, + right, + on, + filter, + &JoinType::Left, + NullEquality::NullEqualsNothing, + )?; 
let columns = columns(&join.schema()); assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]); @@ -3681,7 +3835,14 @@ mod tests { )]; let filter = prepare_join_filter(); - let join = join_with_filter(left, right, on, filter, &JoinType::Right, false)?; + let join = join_with_filter( + left, + right, + on, + filter, + &JoinType::Right, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]); @@ -3725,7 +3886,14 @@ mod tests { )]; let filter = prepare_join_filter(); - let join = join_with_filter(left, right, on, filter, &JoinType::Full, false)?; + let join = join_with_filter( + left, + right, + on, + filter, + &JoinType::Full, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]); @@ -3892,7 +4060,7 @@ mod tests { on.clone(), &join_type, PartitionMode::CollectLeft, - false, + NullEquality::NullEqualsNothing, Arc::clone(&task_ctx), ) .await?; @@ -3924,7 +4092,13 @@ mod tests { Arc::new(Column::new_with_schema("date", &right.schema()).unwrap()) as _, )]; - let join = join(left, right, on, &JoinType::Inner, false)?; + let join = join( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + )?; let task_ctx = Arc::new(TaskContext::default()); let stream = join.execute(0, task_ctx)?; @@ -3983,7 +4157,7 @@ mod tests { Arc::clone(&right_input) as Arc<dyn ExecutionPlan>, on.clone(), &join_type, - false, + NullEquality::NullEqualsNothing, ) .unwrap(); let task_ctx = Arc::new(TaskContext::default()); @@ -4097,7 +4271,7 @@ mod tests { Arc::clone(&right), on.clone(), &join_type, - false, + NullEquality::NullEqualsNothing, ) .unwrap(); @@ -4177,7 +4351,7 @@ mod tests { Arc::clone(&right), on.clone(), &join_type, - false, + NullEquality::NullEqualsNothing, )?; let stream = join.execute(0, task_ctx)?; @@ -4258,7 +4432,7 @@ mod tests { &join_type, None, PartitionMode::Partitioned, - false, + NullEquality::NullEqualsNothing, )?; let stream = join.execute(1, task_ctx)?; @@ -4318,8 +4492,15 @@ mod tests { Arc::new(Column::new_with_schema("n2", &right.schema())?) 
as _, )]; - let (columns, batches) = - join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + let (columns, batches) = join_collect( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; assert_eq!(columns, vec!["n1", "n2"]); @@ -4355,7 +4536,7 @@ mod tests { Arc::clone(&right), on.clone(), &JoinType::Inner, - true, + NullEquality::NullEqualsNull, Arc::clone(&task_ctx), ) .await?; @@ -4370,8 +4551,15 @@ mod tests { "#); } - let (_, batches_null_neq) = - join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + let (_, batches_null_neq) = join_collect( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; let expected_null_neq = ["+----+----+", "| n1 | n2 |", "+----+----+", "+----+----+"]; diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 44021c38a7..fcc1107a0e 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -720,7 +720,7 @@ struct NestedLoopJoinStream<T> { /// Information of index and left / right placement of columns column_indices: Vec<ColumnIndex>, // TODO: support null aware equal - // null_equals_null: bool + // null_equality: NullEquality, /// Join execution metrics join_metrics: BuildProbeJoinMetrics, /// Cache for join indices calculations diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 6ab069aaf4..4d635948ed 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -63,7 +63,7 @@ use arrow::error::ArrowError; use arrow::ipc::reader::StreamReader; use datafusion_common::{ exec_err, internal_err, not_impl_err, plan_err, DataFusionError, HashSet, JoinSide, - JoinType, Result, + JoinType, NullEquality, Result, }; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; @@ -145,8 +145,8 @@ pub struct SortMergeJoinExec { right_sort_exprs: LexOrdering, /// Sort options of join columns used in sorting left and right execution plans pub sort_options: Vec<SortOptions>, - /// If null_equals_null is true, null == null else null != null - pub null_equals_null: bool, + /// Defines the null equality for the join. + pub null_equality: NullEquality, /// Cache holding plan properties like equivalences, output partitioning etc. cache: PlanProperties, } @@ -163,7 +163,7 @@ impl SortMergeJoinExec { filter: Option<JoinFilter>, join_type: JoinType, sort_options: Vec<SortOptions>, - null_equals_null: bool, + null_equality: NullEquality, ) -> Result<Self> { let left_schema = left.schema(); let right_schema = right.schema(); @@ -218,7 +218,7 @@ impl SortMergeJoinExec { left_sort_exprs, right_sort_exprs, sort_options, - null_equals_null, + null_equality, cache, }) } @@ -291,9 +291,9 @@ impl SortMergeJoinExec { &self.sort_options } - /// Null equals null - pub fn null_equals_null(&self) -> bool { - self.null_equals_null + /// Null equality + pub fn null_equality(&self) -> NullEquality { + self.null_equality } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. 
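Inside SortMergeJoinExec the renamed field is consumed the same way: where the old code branched on the boolean, the new code matches on the enum (see the `compare_join_arrays` hunk further down, which returns `Ordering::Equal` for `NullEqualsNull` and `Ordering::Less` for `NullEqualsNothing` when NULL keys are encountered). A standalone sketch of that decision, assuming only `datafusion_common::NullEquality` from this commit; it is a simplification, not code from the patch:

use std::cmp::Ordering;

use datafusion_common::NullEquality;

/// Decides how NULL join keys compare under the configured policy.
fn compare_null_keys(null_equality: NullEquality) -> Ordering {
    match null_equality {
        // NULL never equals anything, so report the keys as unequal.
        NullEquality::NullEqualsNothing => Ordering::Less,
        // NULL == NULL, so the keys compare equal and the rows can join.
        NullEquality::NullEqualsNull => Ordering::Equal,
    }
}

fn main() {
    assert!(matches!(
        compare_null_keys(NullEquality::NullEqualsNull),
        Ordering::Equal
    ));
    assert!(matches!(
        compare_null_keys(NullEquality::NullEqualsNothing),
        Ordering::Less
    ));
}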
@@ -339,7 +339,7 @@ impl SortMergeJoinExec { self.filter().as_ref().map(JoinFilter::swap), self.join_type().swap(), self.sort_options.clone(), - self.null_equals_null, + self.null_equality, )?; // TODO: OR this condition with having a built-in projection (like @@ -450,7 +450,7 @@ impl ExecutionPlan for SortMergeJoinExec { self.filter.clone(), self.join_type, self.sort_options.clone(), - self.null_equals_null, + self.null_equality, )?)), _ => internal_err!("SortMergeJoin wrong number of children"), } @@ -502,7 +502,7 @@ impl ExecutionPlan for SortMergeJoinExec { Ok(Box::pin(SortMergeJoinStream::try_new( Arc::clone(&self.schema), self.sort_options.clone(), - self.null_equals_null, + self.null_equality, streamed, buffered, on_streamed, @@ -591,7 +591,7 @@ impl ExecutionPlan for SortMergeJoinExec { self.filter.clone(), self.join_type, self.sort_options.clone(), - self.null_equals_null, + self.null_equality, )?))) } } @@ -844,8 +844,8 @@ struct SortMergeJoinStream { // ======================================================================== /// Output schema pub schema: SchemaRef, - /// null == null? - pub null_equals_null: bool, + /// Defines the null equality for the join. + pub null_equality: NullEquality, /// Sort options of join columns used to sort streamed and buffered data stream pub sort_options: Vec<SortOptions>, /// optional join filter @@ -1326,7 +1326,7 @@ impl SortMergeJoinStream { pub fn try_new( schema: SchemaRef, sort_options: Vec<SortOptions>, - null_equals_null: bool, + null_equality: NullEquality, streamed: SendableRecordBatchStream, buffered: SendableRecordBatchStream, on_streamed: Vec<Arc<dyn PhysicalExpr>>, @@ -1348,7 +1348,7 @@ impl SortMergeJoinStream { Ok(Self { state: SortMergeJoinState::Init, sort_options, - null_equals_null, + null_equality, schema: Arc::clone(&schema), streamed_schema: Arc::clone(&streamed_schema), buffered_schema, @@ -1593,7 +1593,7 @@ impl SortMergeJoinStream { &self.buffered_data.head_batch().join_arrays, self.buffered_data.head_batch().range.start, &self.sort_options, - self.null_equals_null, + self.null_equality, ) } @@ -2434,7 +2434,7 @@ fn compare_join_arrays( right_arrays: &[ArrayRef], right: usize, sort_options: &[SortOptions], - null_equals_null: bool, + null_equality: NullEquality, ) -> Result<Ordering> { let mut res = Ordering::Equal; for ((left_array, right_array), sort_options) in @@ -2468,10 +2468,9 @@ fn compare_join_arrays( }; } _ => { - res = if null_equals_null { - Ordering::Equal - } else { - Ordering::Less + res = match null_equality { + NullEquality::NullEqualsNothing => Ordering::Less, + NullEquality::NullEqualsNull => Ordering::Equal, }; } } @@ -2597,7 +2596,9 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::JoinType::*; - use datafusion_common::{assert_batches_eq, assert_contains, JoinType, Result}; + use datafusion_common::{ + assert_batches_eq, assert_contains, JoinType, NullEquality, Result, + }; use datafusion_common::{ test_util::{batches_to_sort_string, batches_to_string}, JoinSide, @@ -2722,7 +2723,15 @@ mod tests { join_type: JoinType, ) -> Result<SortMergeJoinExec> { let sort_options = vec![SortOptions::default(); on.len()]; - SortMergeJoinExec::try_new(left, right, on, None, join_type, sort_options, false) + SortMergeJoinExec::try_new( + left, + right, + on, + None, + join_type, + sort_options, + NullEquality::NullEqualsNothing, + ) } fn join_with_options( @@ -2731,7 +2740,7 @@ mod tests { on: JoinOn, join_type: JoinType, sort_options: Vec<SortOptions>, - null_equals_null: 
bool, + null_equality: NullEquality, ) -> Result<SortMergeJoinExec> { SortMergeJoinExec::try_new( left, @@ -2740,7 +2749,7 @@ mod tests { None, join_type, sort_options, - null_equals_null, + null_equality, ) } @@ -2751,7 +2760,7 @@ mod tests { filter: JoinFilter, join_type: JoinType, sort_options: Vec<SortOptions>, - null_equals_null: bool, + null_equality: NullEquality, ) -> Result<SortMergeJoinExec> { SortMergeJoinExec::try_new( left, @@ -2760,7 +2769,7 @@ mod tests { Some(filter), join_type, sort_options, - null_equals_null, + null_equality, ) } @@ -2771,7 +2780,15 @@ mod tests { join_type: JoinType, ) -> Result<(Vec<String>, Vec<RecordBatch>)> { let sort_options = vec![SortOptions::default(); on.len()]; - join_collect_with_options(left, right, on, join_type, sort_options, false).await + join_collect_with_options( + left, + right, + on, + join_type, + sort_options, + NullEquality::NullEqualsNothing, + ) + .await } async fn join_collect_with_filter( @@ -2784,8 +2801,15 @@ mod tests { let sort_options = vec![SortOptions::default(); on.len()]; let task_ctx = Arc::new(TaskContext::default()); - let join = - join_with_filter(left, right, on, filter, join_type, sort_options, false)?; + let join = join_with_filter( + left, + right, + on, + filter, + join_type, + sort_options, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); let stream = join.execute(0, task_ctx)?; @@ -2799,17 +2823,11 @@ mod tests { on: JoinOn, join_type: JoinType, sort_options: Vec<SortOptions>, - null_equals_null: bool, + null_equality: NullEquality, ) -> Result<(Vec<String>, Vec<RecordBatch>)> { let task_ctx = Arc::new(TaskContext::default()); - let join = join_with_options( - left, - right, - on, - join_type, - sort_options, - null_equals_null, - )?; + let join = + join_with_options(left, right, on, join_type, sort_options, null_equality)?; let columns = columns(&join.schema()); let stream = join.execute(0, task_ctx)?; @@ -3015,7 +3033,7 @@ mod tests { }; 2 ], - true, + NullEquality::NullEqualsNull, ) .await?; // The output order is important as SMJ preserves sortedness @@ -3438,7 +3456,7 @@ mod tests { }; 2 ], - true, + NullEquality::NullEqualsNull, ) .await?; @@ -3715,7 +3733,7 @@ mod tests { }; 2 ], - true, + NullEquality::NullEqualsNull, ) .await?; @@ -4159,7 +4177,7 @@ mod tests { on.clone(), join_type, sort_options.clone(), - false, + NullEquality::NullEqualsNothing, )?; let stream = join.execute(0, task_ctx)?; @@ -4240,7 +4258,7 @@ mod tests { on.clone(), join_type, sort_options.clone(), - false, + NullEquality::NullEqualsNothing, )?; let stream = join.execute(0, task_ctx)?; @@ -4303,7 +4321,7 @@ mod tests { on.clone(), *join_type, sort_options.clone(), - false, + NullEquality::NullEqualsNothing, )?; let stream = join.execute(0, task_ctx)?; @@ -4325,7 +4343,7 @@ mod tests { on.clone(), *join_type, sort_options.clone(), - false, + NullEquality::NullEqualsNothing, )?; let stream = join.execute(0, task_ctx_no_spill)?; let no_spilled_join_result = common::collect(stream).await.unwrap(); @@ -4407,7 +4425,7 @@ mod tests { on.clone(), *join_type, sort_options.clone(), - false, + NullEquality::NullEqualsNothing, )?; let stream = join.execute(0, task_ctx)?; @@ -4428,7 +4446,7 @@ mod tests { on.clone(), *join_type, sort_options.clone(), - false, + NullEquality::NullEqualsNothing, )?; let stream = join.execute(0, task_ctx_no_spill)?; let no_spilled_join_result = common::collect(stream).await.unwrap(); diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs 
b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 84575acea5..d540b6d2a3 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -67,7 +67,9 @@ use arrow::datatypes::{ArrowNativeType, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::hash_utils::create_hashes; use datafusion_common::utils::bisect; -use datafusion_common::{internal_err, plan_err, HashSet, JoinSide, JoinType, Result}; +use datafusion_common::{ + internal_err, plan_err, HashSet, JoinSide, JoinType, NullEquality, Result, +}; use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_execution::TaskContext; use datafusion_expr::interval_arithmetic::Interval; @@ -185,8 +187,8 @@ pub struct SymmetricHashJoinExec { metrics: ExecutionPlanMetricsSet, /// Information of index and left / right placement of columns column_indices: Vec<ColumnIndex>, - /// If null_equals_null is true, null == null else null != null - pub(crate) null_equals_null: bool, + /// Defines the null equality for the join. + pub(crate) null_equality: NullEquality, /// Left side sort expression(s) pub(crate) left_sort_exprs: Option<LexOrdering>, /// Right side sort expression(s) @@ -211,7 +213,7 @@ impl SymmetricHashJoinExec { on: JoinOn, filter: Option<JoinFilter>, join_type: &JoinType, - null_equals_null: bool, + null_equality: NullEquality, left_sort_exprs: Option<LexOrdering>, right_sort_exprs: Option<LexOrdering>, mode: StreamJoinPartitionMode, @@ -246,7 +248,7 @@ impl SymmetricHashJoinExec { random_state, metrics: ExecutionPlanMetricsSet::new(), column_indices, - null_equals_null, + null_equality, left_sort_exprs, right_sort_exprs, mode, @@ -310,9 +312,9 @@ impl SymmetricHashJoinExec { &self.join_type } - /// Get null_equals_null - pub fn null_equals_null(&self) -> bool { - self.null_equals_null + /// Get null_equality + pub fn null_equality(&self) -> NullEquality { + self.null_equality } /// Get partition mode @@ -456,7 +458,7 @@ impl ExecutionPlan for SymmetricHashJoinExec { self.on.clone(), self.filter.clone(), &self.join_type, - self.null_equals_null, + self.null_equality, self.left_sort_exprs.clone(), self.right_sort_exprs.clone(), self.mode, @@ -545,7 +547,7 @@ impl ExecutionPlan for SymmetricHashJoinExec { graph, left_sorted_filter_expr, right_sorted_filter_expr, - null_equals_null: self.null_equals_null, + null_equality: self.null_equality, state: SHJStreamState::PullRight, reservation, batch_transformer: BatchSplitter::new(batch_size), @@ -565,7 +567,7 @@ impl ExecutionPlan for SymmetricHashJoinExec { graph, left_sorted_filter_expr, right_sorted_filter_expr, - null_equals_null: self.null_equals_null, + null_equality: self.null_equality, state: SHJStreamState::PullRight, reservation, batch_transformer: NoopBatchTransformer::new(), @@ -637,7 +639,7 @@ impl ExecutionPlan for SymmetricHashJoinExec { new_on, new_filter, self.join_type(), - self.null_equals_null(), + self.null_equality(), self.right().output_ordering().cloned(), self.left().output_ordering().cloned(), self.partition_mode(), @@ -671,8 +673,8 @@ struct SymmetricHashJoinStream<T> { right_sorted_filter_expr: Option<SortedFilterExpr>, /// Random state used for hashing initialization random_state: RandomState, - /// If null_equals_null is true, null == null else null != null - null_equals_null: bool, + /// Defines the null equality for the join. 
+ null_equality: NullEquality, /// Metrics metrics: StreamJoinMetrics, /// Memory reservation @@ -934,7 +936,7 @@ pub(crate) fn build_side_determined_results( /// * `probe_batch` - The second record batch to be joined. /// * `column_indices` - An array of columns to be selected for the result of the join. /// * `random_state` - The random state for the join. -/// * `null_equals_null` - A boolean indicating whether NULL values should be treated as equal when joining. +/// * `null_equality` - Indicates whether NULL values should be treated as equal when joining. /// /// # Returns /// @@ -950,7 +952,7 @@ pub(crate) fn join_with_probe_batch( probe_batch: &RecordBatch, column_indices: &[ColumnIndex], random_state: &RandomState, - null_equals_null: bool, + null_equality: NullEquality, ) -> Result<Option<RecordBatch>> { if build_hash_joiner.input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 { return Ok(None); @@ -962,7 +964,7 @@ pub(crate) fn join_with_probe_batch( &build_hash_joiner.on, &probe_hash_joiner.on, random_state, - null_equals_null, + null_equality, &mut build_hash_joiner.hashes_buffer, Some(build_hash_joiner.deleted_offset), )?; @@ -1028,7 +1030,7 @@ pub(crate) fn join_with_probe_batch( /// * `build_on` - An array of columns on which the join will be performed. The columns are from the build side of the join. /// * `probe_on` - An array of columns on which the join will be performed. The columns are from the probe side of the join. /// * `random_state` - The random state for the join. -/// * `null_equals_null` - A boolean indicating whether NULL values should be treated as equal when joining. +/// * `null_equality` - Indicates whether NULL values should be treated as equal when joining. /// * `hashes_buffer` - Buffer used for probe side keys hash calculation. /// * `deleted_offset` - deleted offset for build side data. 
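The doc comments above describe `null_equality` as deciding whether NULL values are treated as equal when joining. Reduced to plain Rust over nullable keys, the two policies behave as sketched below (a standalone illustration, not code from the patch; only `datafusion_common::NullEquality` is assumed):

use datafusion_common::NullEquality;

/// Returns true if two nullable join keys should be considered a match.
fn keys_match(left: Option<i64>, right: Option<i64>, null_equality: NullEquality) -> bool {
    match (left, right) {
        (Some(l), Some(r)) => l == r,
        // Both keys are NULL: the policy decides.
        (None, None) => matches!(null_equality, NullEquality::NullEqualsNull),
        // A NULL key never matches a non-NULL key under either policy.
        _ => false,
    }
}

fn main() {
    // Two NULL keys only match under NullEqualsNull, the behaviour that
    // INTERSECT / EXCEPT rely on (see the sqllogictest comment further down).
    assert!(keys_match(None, None, NullEquality::NullEqualsNull));
    assert!(!keys_match(None, None, NullEquality::NullEqualsNothing));
    assert!(keys_match(Some(1), Some(1), NullEquality::NullEqualsNothing));
    assert!(!keys_match(None, Some(1), NullEquality::NullEqualsNull));
}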
/// @@ -1044,7 +1046,7 @@ fn lookup_join_hashmap( build_on: &[PhysicalExprRef], probe_on: &[PhysicalExprRef], random_state: &RandomState, - null_equals_null: bool, + null_equality: NullEquality, hashes_buffer: &mut Vec<u64>, deleted_offset: Option<usize>, ) -> Result<(UInt64Array, UInt32Array)> { @@ -1105,7 +1107,7 @@ fn lookup_join_hashmap( &probe_indices, &build_join_values, &keys_values, - null_equals_null, + null_equality, )?; Ok((build_indices, probe_indices)) @@ -1602,7 +1604,7 @@ impl<T: BatchTransformer> SymmetricHashJoinStream<T> { size += size_of_val(&self.left_sorted_filter_expr); size += size_of_val(&self.right_sorted_filter_expr); size += size_of_val(&self.random_state); - size += size_of_val(&self.null_equals_null); + size += size_of_val(&self.null_equality); size += size_of_val(&self.metrics); size } @@ -1657,7 +1659,7 @@ impl<T: BatchTransformer> SymmetricHashJoinStream<T> { &probe_batch, &self.column_indices, &self.random_state, - self.null_equals_null, + self.null_equality, )?; // Increment the offset for the probe hash joiner: probe_hash_joiner.offset += probe_batch.num_rows(); @@ -1813,12 +1815,18 @@ mod tests { on.clone(), filter.clone(), &join_type, - false, + NullEquality::NullEqualsNothing, Arc::clone(&task_ctx), ) .await?; let second_batches = partitioned_hash_join_with_filter( - left, right, on, filter, &join_type, false, task_ctx, + left, + right, + on, + filter, + &join_type, + NullEquality::NullEqualsNothing, + task_ctx, ) .await?; compare_batches(&first_batches, &second_batches); diff --git a/datafusion/physical-plan/src/joins/test_utils.rs b/datafusion/physical-plan/src/joins/test_utils.rs index cbabd7cb45..ea893cc933 100644 --- a/datafusion/physical-plan/src/joins/test_utils.rs +++ b/datafusion/physical-plan/src/joins/test_utils.rs @@ -33,7 +33,7 @@ use arrow::array::{ }; use arrow::datatypes::{DataType, Schema}; use arrow::util::pretty::pretty_format_batches; -use datafusion_common::{Result, ScalarValue}; +use datafusion_common::{NullEquality, Result, ScalarValue}; use datafusion_execution::TaskContext; use datafusion_expr::{JoinType, Operator}; use datafusion_physical_expr::expressions::{binary, cast, col, lit}; @@ -74,7 +74,7 @@ pub async fn partitioned_sym_join_with_filter( on: JoinOn, filter: Option<JoinFilter>, join_type: &JoinType, - null_equals_null: bool, + null_equality: NullEquality, context: Arc<TaskContext>, ) -> Result<Vec<RecordBatch>> { let partition_count = 4; @@ -101,7 +101,7 @@ pub async fn partitioned_sym_join_with_filter( on, filter, join_type, - null_equals_null, + null_equality, left.output_ordering().cloned(), right.output_ordering().cloned(), StreamJoinPartitionMode::Partitioned, @@ -128,7 +128,7 @@ pub async fn partitioned_hash_join_with_filter( on: JoinOn, filter: Option<JoinFilter>, join_type: &JoinType, - null_equals_null: bool, + null_equality: NullEquality, context: Arc<TaskContext>, ) -> Result<Vec<RecordBatch>> { let partition_count = 4; @@ -151,7 +151,7 @@ pub async fn partitioned_hash_join_with_filter( join_type, None, PartitionMode::Partitioned, - null_equals_null, + null_equality, )?); let mut batches = vec![]; diff --git a/datafusion/proto-common/proto/datafusion_common.proto b/datafusion/proto-common/proto/datafusion_common.proto index 9eab33928a..81fc9cceb7 100644 --- a/datafusion/proto-common/proto/datafusion_common.proto +++ b/datafusion/proto-common/proto/datafusion_common.proto @@ -93,6 +93,11 @@ enum JoinConstraint { USING = 1; } +enum NullEquality { + NULL_EQUALS_NOTHING = 0; + NULL_EQUALS_NULL = 1; +} + message 
AvroOptions {} message ArrowOptions {} diff --git a/datafusion/proto-common/src/generated/pbjson.rs b/datafusion/proto-common/src/generated/pbjson.rs index 0c593a36b8..c3b6686df0 100644 --- a/datafusion/proto-common/src/generated/pbjson.rs +++ b/datafusion/proto-common/src/generated/pbjson.rs @@ -4418,6 +4418,77 @@ impl<'de> serde::Deserialize<'de> for NdJsonFormat { deserializer.deserialize_struct("datafusion_common.NdJsonFormat", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for NullEquality { + #[allow(deprecated)] + fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> + where + S: serde::Serializer, + { + let variant = match self { + Self::NullEqualsNothing => "NULL_EQUALS_NOTHING", + Self::NullEqualsNull => "NULL_EQUALS_NULL", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for NullEquality { + #[allow(deprecated)] + fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "NULL_EQUALS_NOTHING", + "NULL_EQUALS_NULL", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = NullEquality; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64<E>(self, v: i64) -> std::result::Result<Self::Value, E> + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64<E>(self, v: u64) -> std::result::Result<Self::Value, E> + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str<E>(self, value: &str) -> std::result::Result<Self::Value, E> + where + E: serde::de::Error, + { + match value { + "NULL_EQUALS_NOTHING" => Ok(NullEquality::NullEqualsNothing), + "NULL_EQUALS_NULL" => Ok(NullEquality::NullEqualsNull), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} impl serde::Serialize for ParquetColumnOptions { #[allow(deprecated)] fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> diff --git a/datafusion/proto-common/src/generated/prost.rs b/datafusion/proto-common/src/generated/prost.rs index c051dd00f7..411d72af4c 100644 --- a/datafusion/proto-common/src/generated/prost.rs +++ b/datafusion/proto-common/src/generated/prost.rs @@ -970,6 +970,32 @@ impl JoinConstraint { } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] +pub enum NullEquality { + NullEqualsNothing = 0, + NullEqualsNull = 1, +} +impl NullEquality { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::NullEqualsNothing => "NULL_EQUALS_NOTHING", + Self::NullEqualsNull => "NULL_EQUALS_NULL", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. 
+ pub fn from_str_name(value: &str) -> ::core::option::Option<Self> { + match value { + "NULL_EQUALS_NOTHING" => Some(Self::NullEqualsNothing), + "NULL_EQUALS_NULL" => Some(Self::NullEqualsNull), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] pub enum TimeUnit { Second = 0, Millisecond = 1, diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 76cd8c9118..1e1f91e07e 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -243,7 +243,7 @@ message JoinNode { datafusion_common.JoinConstraint join_constraint = 4; repeated LogicalExprNode left_join_key = 5; repeated LogicalExprNode right_join_key = 6; - bool null_equals_null = 7; + datafusion_common.NullEquality null_equality = 7; LogicalExprNode filter = 8; } @@ -1051,7 +1051,7 @@ message HashJoinExecNode { repeated JoinOn on = 3; datafusion_common.JoinType join_type = 4; PartitionMode partition_mode = 6; - bool null_equals_null = 7; + datafusion_common.NullEquality null_equality = 7; JoinFilter filter = 8; repeated uint32 projection = 9; } @@ -1067,7 +1067,7 @@ message SymmetricHashJoinExecNode { repeated JoinOn on = 3; datafusion_common.JoinType join_type = 4; StreamPartitionMode partition_mode = 6; - bool null_equals_null = 7; + datafusion_common.NullEquality null_equality = 7; JoinFilter filter = 8; repeated PhysicalSortExprNode left_sort_exprs = 9; repeated PhysicalSortExprNode right_sort_exprs = 10; diff --git a/datafusion/proto/src/generated/datafusion_proto_common.rs b/datafusion/proto/src/generated/datafusion_proto_common.rs index c051dd00f7..411d72af4c 100644 --- a/datafusion/proto/src/generated/datafusion_proto_common.rs +++ b/datafusion/proto/src/generated/datafusion_proto_common.rs @@ -970,6 +970,32 @@ impl JoinConstraint { } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] +pub enum NullEquality { + NullEqualsNothing = 0, + NullEqualsNull = 1, +} +impl NullEquality { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::NullEqualsNothing => "NULL_EQUALS_NOTHING", + Self::NullEqualsNull => "NULL_EQUALS_NULL", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. 
+ pub fn from_str_name(value: &str) -> ::core::option::Option<Self> { + match value { + "NULL_EQUALS_NOTHING" => Some(Self::NullEqualsNothing), + "NULL_EQUALS_NULL" => Some(Self::NullEqualsNull), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] pub enum TimeUnit { Second = 0, Millisecond = 1, diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 8a62be84ec..02a1cc70ee 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -6851,7 +6851,7 @@ impl serde::Serialize for HashJoinExecNode { if self.partition_mode != 0 { len += 1; } - if self.null_equals_null { + if self.null_equality != 0 { len += 1; } if self.filter.is_some() { @@ -6880,8 +6880,10 @@ impl serde::Serialize for HashJoinExecNode { .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.partition_mode)))?; struct_ser.serialize_field("partitionMode", &v)?; } - if self.null_equals_null { - struct_ser.serialize_field("nullEqualsNull", &self.null_equals_null)?; + if self.null_equality != 0 { + let v = super::datafusion_common::NullEquality::try_from(self.null_equality) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.null_equality)))?; + struct_ser.serialize_field("nullEquality", &v)?; } if let Some(v) = self.filter.as_ref() { struct_ser.serialize_field("filter", v)?; @@ -6906,8 +6908,8 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { "joinType", "partition_mode", "partitionMode", - "null_equals_null", - "nullEqualsNull", + "null_equality", + "nullEquality", "filter", "projection", ]; @@ -6919,7 +6921,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { On, JoinType, PartitionMode, - NullEqualsNull, + NullEquality, Filter, Projection, } @@ -6948,7 +6950,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { "on" => Ok(GeneratedField::On), "joinType" | "join_type" => Ok(GeneratedField::JoinType), "partitionMode" | "partition_mode" => Ok(GeneratedField::PartitionMode), - "nullEqualsNull" | "null_equals_null" => Ok(GeneratedField::NullEqualsNull), + "nullEquality" | "null_equality" => Ok(GeneratedField::NullEquality), "filter" => Ok(GeneratedField::Filter), "projection" => Ok(GeneratedField::Projection), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), @@ -6975,7 +6977,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { let mut on__ = None; let mut join_type__ = None; let mut partition_mode__ = None; - let mut null_equals_null__ = None; + let mut null_equality__ = None; let mut filter__ = None; let mut projection__ = None; while let Some(k) = map_.next_key()? { @@ -7010,11 +7012,11 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { } partition_mode__ = Some(map_.next_value::<PartitionMode>()? as i32); } - GeneratedField::NullEqualsNull => { - if null_equals_null__.is_some() { - return Err(serde::de::Error::duplicate_field("nullEqualsNull")); + GeneratedField::NullEquality => { + if null_equality__.is_some() { + return Err(serde::de::Error::duplicate_field("nullEquality")); } - null_equals_null__ = Some(map_.next_value()?); + null_equality__ = Some(map_.next_value::<super::datafusion_common::NullEquality>()? 
as i32); } GeneratedField::Filter => { if filter__.is_some() { @@ -7039,7 +7041,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { on: on__.unwrap_or_default(), join_type: join_type__.unwrap_or_default(), partition_mode: partition_mode__.unwrap_or_default(), - null_equals_null: null_equals_null__.unwrap_or_default(), + null_equality: null_equality__.unwrap_or_default(), filter: filter__, projection: projection__.unwrap_or_default(), }) @@ -8475,7 +8477,7 @@ impl serde::Serialize for JoinNode { if !self.right_join_key.is_empty() { len += 1; } - if self.null_equals_null { + if self.null_equality != 0 { len += 1; } if self.filter.is_some() { @@ -8504,8 +8506,10 @@ impl serde::Serialize for JoinNode { if !self.right_join_key.is_empty() { struct_ser.serialize_field("rightJoinKey", &self.right_join_key)?; } - if self.null_equals_null { - struct_ser.serialize_field("nullEqualsNull", &self.null_equals_null)?; + if self.null_equality != 0 { + let v = super::datafusion_common::NullEquality::try_from(self.null_equality) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.null_equality)))?; + struct_ser.serialize_field("nullEquality", &v)?; } if let Some(v) = self.filter.as_ref() { struct_ser.serialize_field("filter", v)?; @@ -8530,8 +8534,8 @@ impl<'de> serde::Deserialize<'de> for JoinNode { "leftJoinKey", "right_join_key", "rightJoinKey", - "null_equals_null", - "nullEqualsNull", + "null_equality", + "nullEquality", "filter", ]; @@ -8543,7 +8547,7 @@ impl<'de> serde::Deserialize<'de> for JoinNode { JoinConstraint, LeftJoinKey, RightJoinKey, - NullEqualsNull, + NullEquality, Filter, } impl<'de> serde::Deserialize<'de> for GeneratedField { @@ -8572,7 +8576,7 @@ impl<'de> serde::Deserialize<'de> for JoinNode { "joinConstraint" | "join_constraint" => Ok(GeneratedField::JoinConstraint), "leftJoinKey" | "left_join_key" => Ok(GeneratedField::LeftJoinKey), "rightJoinKey" | "right_join_key" => Ok(GeneratedField::RightJoinKey), - "nullEqualsNull" | "null_equals_null" => Ok(GeneratedField::NullEqualsNull), + "nullEquality" | "null_equality" => Ok(GeneratedField::NullEquality), "filter" => Ok(GeneratedField::Filter), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } @@ -8599,7 +8603,7 @@ impl<'de> serde::Deserialize<'de> for JoinNode { let mut join_constraint__ = None; let mut left_join_key__ = None; let mut right_join_key__ = None; - let mut null_equals_null__ = None; + let mut null_equality__ = None; let mut filter__ = None; while let Some(k) = map_.next_key()? { match k { @@ -8639,11 +8643,11 @@ impl<'de> serde::Deserialize<'de> for JoinNode { } right_join_key__ = Some(map_.next_value()?); } - GeneratedField::NullEqualsNull => { - if null_equals_null__.is_some() { - return Err(serde::de::Error::duplicate_field("nullEqualsNull")); + GeneratedField::NullEquality => { + if null_equality__.is_some() { + return Err(serde::de::Error::duplicate_field("nullEquality")); } - null_equals_null__ = Some(map_.next_value()?); + null_equality__ = Some(map_.next_value::<super::datafusion_common::NullEquality>()? 
as i32); } GeneratedField::Filter => { if filter__.is_some() { @@ -8660,7 +8664,7 @@ impl<'de> serde::Deserialize<'de> for JoinNode { join_constraint: join_constraint__.unwrap_or_default(), left_join_key: left_join_key__.unwrap_or_default(), right_join_key: right_join_key__.unwrap_or_default(), - null_equals_null: null_equals_null__.unwrap_or_default(), + null_equality: null_equality__.unwrap_or_default(), filter: filter__, }) } @@ -20041,7 +20045,7 @@ impl serde::Serialize for SymmetricHashJoinExecNode { if self.partition_mode != 0 { len += 1; } - if self.null_equals_null { + if self.null_equality != 0 { len += 1; } if self.filter.is_some() { @@ -20073,8 +20077,10 @@ impl serde::Serialize for SymmetricHashJoinExecNode { .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.partition_mode)))?; struct_ser.serialize_field("partitionMode", &v)?; } - if self.null_equals_null { - struct_ser.serialize_field("nullEqualsNull", &self.null_equals_null)?; + if self.null_equality != 0 { + let v = super::datafusion_common::NullEquality::try_from(self.null_equality) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.null_equality)))?; + struct_ser.serialize_field("nullEquality", &v)?; } if let Some(v) = self.filter.as_ref() { struct_ser.serialize_field("filter", v)?; @@ -20102,8 +20108,8 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { "joinType", "partition_mode", "partitionMode", - "null_equals_null", - "nullEqualsNull", + "null_equality", + "nullEquality", "filter", "left_sort_exprs", "leftSortExprs", @@ -20118,7 +20124,7 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { On, JoinType, PartitionMode, - NullEqualsNull, + NullEquality, Filter, LeftSortExprs, RightSortExprs, @@ -20148,7 +20154,7 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { "on" => Ok(GeneratedField::On), "joinType" | "join_type" => Ok(GeneratedField::JoinType), "partitionMode" | "partition_mode" => Ok(GeneratedField::PartitionMode), - "nullEqualsNull" | "null_equals_null" => Ok(GeneratedField::NullEqualsNull), + "nullEquality" | "null_equality" => Ok(GeneratedField::NullEquality), "filter" => Ok(GeneratedField::Filter), "leftSortExprs" | "left_sort_exprs" => Ok(GeneratedField::LeftSortExprs), "rightSortExprs" | "right_sort_exprs" => Ok(GeneratedField::RightSortExprs), @@ -20176,7 +20182,7 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { let mut on__ = None; let mut join_type__ = None; let mut partition_mode__ = None; - let mut null_equals_null__ = None; + let mut null_equality__ = None; let mut filter__ = None; let mut left_sort_exprs__ = None; let mut right_sort_exprs__ = None; @@ -20212,11 +20218,11 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { } partition_mode__ = Some(map_.next_value::<StreamPartitionMode>()? as i32); } - GeneratedField::NullEqualsNull => { - if null_equals_null__.is_some() { - return Err(serde::de::Error::duplicate_field("nullEqualsNull")); + GeneratedField::NullEquality => { + if null_equality__.is_some() { + return Err(serde::de::Error::duplicate_field("nullEquality")); } - null_equals_null__ = Some(map_.next_value()?); + null_equality__ = Some(map_.next_value::<super::datafusion_common::NullEquality>()? 
as i32); } GeneratedField::Filter => { if filter__.is_some() { @@ -20244,7 +20250,7 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { on: on__.unwrap_or_default(), join_type: join_type__.unwrap_or_default(), partition_mode: partition_mode__.unwrap_or_default(), - null_equals_null: null_equals_null__.unwrap_or_default(), + null_equality: null_equality__.unwrap_or_default(), filter: filter__, left_sort_exprs: left_sort_exprs__.unwrap_or_default(), right_sort_exprs: right_sort_exprs__.unwrap_or_default(), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 3e3a04051f..c1f8fa61f3 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -369,8 +369,8 @@ pub struct JoinNode { pub left_join_key: ::prost::alloc::vec::Vec<LogicalExprNode>, #[prost(message, repeated, tag = "6")] pub right_join_key: ::prost::alloc::vec::Vec<LogicalExprNode>, - #[prost(bool, tag = "7")] - pub null_equals_null: bool, + #[prost(enumeration = "super::datafusion_common::NullEquality", tag = "7")] + pub null_equality: i32, #[prost(message, optional, tag = "8")] pub filter: ::core::option::Option<LogicalExprNode>, } @@ -1592,8 +1592,8 @@ pub struct HashJoinExecNode { pub join_type: i32, #[prost(enumeration = "PartitionMode", tag = "6")] pub partition_mode: i32, - #[prost(bool, tag = "7")] - pub null_equals_null: bool, + #[prost(enumeration = "super::datafusion_common::NullEquality", tag = "7")] + pub null_equality: i32, #[prost(message, optional, tag = "8")] pub filter: ::core::option::Option<JoinFilter>, #[prost(uint32, repeated, tag = "9")] @@ -1611,8 +1611,8 @@ pub struct SymmetricHashJoinExecNode { pub join_type: i32, #[prost(enumeration = "StreamPartitionMode", tag = "6")] pub partition_mode: i32, - #[prost(bool, tag = "7")] - pub null_equals_null: bool, + #[prost(enumeration = "super::datafusion_common::NullEquality", tag = "7")] + pub null_equality: i32, #[prost(message, optional, tag = "8")] pub filter: ::core::option::Option<JoinFilter>, #[prost(message, repeated, tag = "9")] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 162ed7ae25..66ef0ebfe3 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -19,8 +19,8 @@ use std::sync::Arc; use datafusion::execution::registry::FunctionRegistry; use datafusion_common::{ - exec_datafusion_err, internal_err, plan_datafusion_err, RecursionUnnestOption, - Result, ScalarValue, TableReference, UnnestOptions, + exec_datafusion_err, internal_err, plan_datafusion_err, NullEquality, + RecursionUnnestOption, Result, ScalarValue, TableReference, UnnestOptions, }; use datafusion_expr::dml::InsertOp; use datafusion_expr::expr::{Alias, Placeholder, Sort}; @@ -219,6 +219,15 @@ impl From<protobuf::JoinConstraint> for JoinConstraint { } } +impl From<protobuf::NullEquality> for NullEquality { + fn from(t: protobuf::NullEquality) -> Self { + match t { + protobuf::NullEquality::NullEqualsNothing => NullEquality::NullEqualsNothing, + protobuf::NullEquality::NullEqualsNull => NullEquality::NullEqualsNull, + } + } +} + impl From<protobuf::dml_node::Type> for WriteOp { fn from(t: protobuf::dml_node::Type) -> Self { match t { diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index d934b24dc3..1acf1ee27b 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ 
-1327,7 +1327,7 @@ impl AsLogicalPlan for LogicalPlanNode { filter, join_type, join_constraint, - null_equals_null, + null_equality, .. }) => { let left: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( @@ -1352,6 +1352,8 @@ impl AsLogicalPlan for LogicalPlanNode { let join_type: protobuf::JoinType = join_type.to_owned().into(); let join_constraint: protobuf::JoinConstraint = join_constraint.to_owned().into(); + let null_equality: protobuf::NullEquality = + null_equality.to_owned().into(); let filter = filter .as_ref() .map(|e| serialize_expr(e, extension_codec)) @@ -1365,7 +1367,7 @@ impl AsLogicalPlan for LogicalPlanNode { join_constraint: join_constraint.into(), left_join_key, right_join_key, - null_equals_null: *null_equals_null, + null_equality: null_equality.into(), filter, }, ))), diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index f7f303fefe..b14ad7aadf 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -21,7 +21,7 @@ use std::collections::HashMap; -use datafusion_common::{TableReference, UnnestOptions}; +use datafusion_common::{NullEquality, TableReference, UnnestOptions}; use datafusion_expr::dml::InsertOp; use datafusion_expr::expr::{ self, AggregateFunctionParams, Alias, Between, BinaryExpr, Cast, GroupingSet, InList, @@ -699,6 +699,15 @@ impl From<JoinConstraint> for protobuf::JoinConstraint { } } +impl From<NullEquality> for protobuf::NullEquality { + fn from(t: NullEquality) -> Self { + match t { + NullEquality::NullEqualsNothing => protobuf::NullEquality::NullEqualsNothing, + NullEquality::NullEqualsNull => protobuf::NullEquality::NullEqualsNull, + } + } +} + impl From<&WriteOp> for protobuf::dml_node::Type { fn from(t: &WriteOp) -> Self { match t { diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index e2c391d044..3d541f54fe 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -1152,6 +1152,13 @@ impl protobuf::PhysicalPlanNode { hashjoin.join_type )) })?; + let null_equality = protobuf::NullEquality::try_from(hashjoin.null_equality) + .map_err(|_| { + proto_error(format!( + "Received a HashJoinNode message with unknown NullEquality {}", + hashjoin.null_equality + )) + })?; let filter = hashjoin .filter .as_ref() @@ -1220,7 +1227,7 @@ impl protobuf::PhysicalPlanNode { &join_type.into(), projection, partition_mode, - hashjoin.null_equals_null, + null_equality.into(), )?)) } @@ -1263,6 +1270,13 @@ impl protobuf::PhysicalPlanNode { sym_join.join_type )) })?; + let null_equality = protobuf::NullEquality::try_from(sym_join.null_equality) + .map_err(|_| { + proto_error(format!( + "Received a SymmetricHashJoin message with unknown NullEquality {}", + sym_join.null_equality + )) + })?; let filter = sym_join .filter .as_ref() @@ -1339,7 +1353,7 @@ impl protobuf::PhysicalPlanNode { on, filter, &join_type.into(), - sym_join.null_equals_null, + null_equality.into(), left_sort_exprs, right_sort_exprs, partition_mode, @@ -1950,6 +1964,7 @@ impl protobuf::PhysicalPlanNode { }) .collect::<Result<_>>()?; let join_type: protobuf::JoinType = exec.join_type().to_owned().into(); + let null_equality: protobuf::NullEquality = exec.null_equality().into(); let filter = exec .filter() .as_ref() @@ -1990,7 +2005,7 @@ impl protobuf::PhysicalPlanNode { on, join_type: join_type.into(), partition_mode: partition_mode.into(), - null_equals_null: exec.null_equals_null(), + 
null_equality: null_equality.into(), filter, projection: exec.projection.as_ref().map_or_else(Vec::new, |v| { v.iter().map(|x| *x as u32).collect::<Vec<u32>>() @@ -2025,6 +2040,7 @@ impl protobuf::PhysicalPlanNode { }) .collect::<Result<_>>()?; let join_type: protobuf::JoinType = exec.join_type().to_owned().into(); + let null_equality: protobuf::NullEquality = exec.null_equality().into(); let filter = exec .filter() .as_ref() @@ -2108,7 +2124,7 @@ impl protobuf::PhysicalPlanNode { on, join_type: join_type.into(), partition_mode: partition_mode.into(), - null_equals_null: exec.null_equals_null(), + null_equality: null_equality.into(), left_sort_exprs, right_sort_exprs, filter, diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index ab419f3713..43f9942a0a 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -98,7 +98,7 @@ use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; use datafusion_common::{ - internal_err, not_impl_err, DataFusionError, Result, UnnestOptions, + internal_err, not_impl_err, DataFusionError, NullEquality, Result, UnnestOptions, }; use datafusion_expr::{ Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF, @@ -268,7 +268,7 @@ fn roundtrip_hash_join() -> Result<()> { join_type, None, *partition_mode, - false, + NullEquality::NullEqualsNothing, )?))?; } } @@ -1494,7 +1494,7 @@ fn roundtrip_sym_hash_join() -> Result<()> { on.clone(), None, join_type, - false, + NullEquality::NullEqualsNothing, left_order.clone(), right_order, *partition_mode, diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index ccecb94943..3be5c1b1c3 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -1373,7 +1373,7 @@ inner join join_t4 on join_t3.s3 = join_t4.s4 {id: 2} {id: 2} # join with struct key and nulls -# Note that intersect or except applies `null_equals_null` as true for Join. +# Note that intersect or except applies `null_equality` as `NullEquality::NullEqualsNull` for Join. query ? 
SELECT * FROM join_t3 EXCEPT diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/join_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/join_rel.rs index fab43a5ff4..0cf920dd62 100644 --- a/datafusion/substrait/src/logical_plan/consumer/rel/join_rel.rs +++ b/datafusion/substrait/src/logical_plan/consumer/rel/join_rel.rs @@ -17,7 +17,7 @@ use crate::logical_plan::consumer::utils::requalify_sides_if_needed; use crate::logical_plan::consumer::SubstraitConsumer; -use datafusion::common::{not_impl_err, plan_err, Column, JoinType}; +use datafusion::common::{not_impl_err, plan_err, Column, JoinType, NullEquality}; use datafusion::logical_expr::utils::split_conjunction; use datafusion::logical_expr::{ BinaryExpr, Expr, LogicalPlan, LogicalPlanBuilder, Operator, @@ -59,19 +59,30 @@ pub async fn from_join_rel( split_eq_and_noneq_join_predicate_with_nulls_equality(&on); let (left_cols, right_cols): (Vec<_>, Vec<_>) = itertools::multiunzip(join_ons); + let null_equality = if nulls_equal_nulls { + NullEquality::NullEqualsNull + } else { + NullEquality::NullEqualsNothing + }; left.join_detailed( right.build()?, join_type, (left_cols, right_cols), join_filter, - nulls_equal_nulls, + null_equality, )? .build() } None => { let on: Vec<String> = vec![]; - left.join_detailed(right.build()?, join_type, (on.clone(), on), None, false)? - .build() + left.join_detailed( + right.build()?, + join_type, + (on.clone(), on), + None, + NullEquality::NullEqualsNothing, + )? + .build() } } } diff --git a/datafusion/substrait/src/logical_plan/producer/rel/join.rs b/datafusion/substrait/src/logical_plan/producer/rel/join.rs index 65c3e426d2..3dbac636fe 100644 --- a/datafusion/substrait/src/logical_plan/producer/rel/join.rs +++ b/datafusion/substrait/src/logical_plan/producer/rel/join.rs @@ -16,7 +16,9 @@ // under the License. use crate::logical_plan::producer::{make_binary_op_scalar_func, SubstraitProducer}; -use datafusion::common::{not_impl_err, DFSchemaRef, JoinConstraint, JoinType}; +use datafusion::common::{ + not_impl_err, DFSchemaRef, JoinConstraint, JoinType, NullEquality, +}; use datafusion::logical_expr::{Expr, Join, Operator}; use std::sync::Arc; use substrait::proto::rel::RelType; @@ -44,10 +46,9 @@ pub fn from_join( // map the left and right columns to binary expressions in the form `l = r` // build a single expression for the ON condition, such as `l.a = r.a AND l.b = r.b` - let eq_op = if join.null_equals_null { - Operator::IsNotDistinctFrom - } else { - Operator::Eq + let eq_op = match join.null_equality { + NullEquality::NullEqualsNothing => Operator::Eq, + NullEquality::NullEqualsNull => Operator::IsNotDistinctFrom, }; let join_on = to_substrait_join_expr(producer, &join.on, eq_op, &in_join_schema)?; --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org For additional commands, e-mail: commits-h...@datafusion.apache.org
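As the Substrait producer hunk above shows, the enum also pins down which SQL comparison the join keys use: `NullEqualsNothing` maps to `=` (`Operator::Eq`) and `NullEqualsNull` to `IS NOT DISTINCT FROM` (`Operator::IsNotDistinctFrom`), which is why the INTERSECT/EXCEPT plans mentioned in joins.slt need `NullEqualsNull`. A standalone sketch of that mapping, assuming `datafusion_common::NullEquality` and `datafusion_expr::Operator` at this commit (the function name is illustrative, not from the patch):

use datafusion_common::NullEquality;
use datafusion_expr::Operator;

/// Chooses the key-comparison operator for a given null-equality policy.
fn key_comparison_operator(null_equality: NullEquality) -> Operator {
    match null_equality {
        // `a = b` is not true when either side is NULL.
        NullEquality::NullEqualsNothing => Operator::Eq,
        // `a IS NOT DISTINCT FROM b` is true when both sides are NULL.
        NullEquality::NullEqualsNull => Operator::IsNotDistinctFrom,
    }
}

fn main() {
    assert!(matches!(
        key_comparison_operator(NullEquality::NullEqualsNull),
        Operator::IsNotDistinctFrom
    ));
    assert!(matches!(
        key_comparison_operator(NullEquality::NullEqualsNothing),
        Operator::Eq
    ));
}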