This is an automated email from the ASF dual-hosted git repository. alamb pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push: new 7d16764ec5 feat: Support RightMark join for NestedLoop and Hash join (#16083) 7d16764ec5 is described below commit 7d16764ec5f407db4f0d9b6eb5c0970913fc894b Author: Jonathan Chen <chenleejonat...@gmail.com> AuthorDate: Mon Jun 16 17:04:12 2025 -0400 feat: Support RightMark join for NestedLoop and Hash join (#16083) * feat: Support RightMark join for NestedLoop and Hash join * fixes * producer fix * fmt * update * fix * rem file * fix * fmt * Update datafusion/physical-plan/src/joins/utils.rs Co-authored-by: Christian <9384305+c...@users.noreply.github.com> * fixes * clippy * refactor --------- Co-authored-by: Christian <9384305+c...@users.noreply.github.com> Co-authored-by: Oleks V <comph...@users.noreply.github.com> --- datafusion/common/src/functional_dependencies.rs | 2 +- datafusion/common/src/join_type.rs | 14 ++- datafusion/core/tests/dataframe/mod.rs | 4 +- datafusion/core/tests/fuzz_cases/join_fuzz.rs | 25 ++++ .../physical_optimizer/enforce_distribution.rs | 5 +- datafusion/expr/src/logical_plan/builder.rs | 4 + datafusion/expr/src/logical_plan/invariants.rs | 5 +- datafusion/expr/src/logical_plan/plan.rs | 8 +- .../optimizer/src/optimize_projections/mod.rs | 11 +- datafusion/optimizer/src/push_down_filter.rs | 5 +- datafusion/physical-expr/src/equivalence/class.rs | 4 +- .../physical-optimizer/src/enforce_distribution.rs | 2 +- .../src/enforce_sorting/sort_pushdown.rs | 3 +- datafusion/physical-plan/src/joins/hash_join.rs | 134 +++++++++++++++++++-- .../physical-plan/src/joins/nested_loop_join.rs | 67 +++++++++-- .../physical-plan/src/joins/sort_merge_join.rs | 13 +- .../physical-plan/src/joins/symmetric_hash_join.rs | 20 ++- datafusion/physical-plan/src/joins/utils.rs | 93 ++++++++++---- .../proto-common/proto/datafusion_common.proto | 1 + datafusion/proto-common/src/generated/pbjson.rs | 3 + datafusion/proto-common/src/generated/prost.rs | 3 + .../proto/src/generated/datafusion_proto_common.rs | 3 + datafusion/proto/src/logical_plan/from_proto.rs | 1 + datafusion/proto/src/logical_plan/to_proto.rs | 1 + datafusion/sql/src/unparser/plan.rs | 36 ++++-- .../src/logical_plan/consumer/rel/join_rel.rs | 1 + .../src/logical_plan/producer/rel/join.rs | 1 + 27 files changed, 387 insertions(+), 82 deletions(-) diff --git a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs index 737a9c35f7..63962998ad 100644 --- a/datafusion/common/src/functional_dependencies.rs +++ b/datafusion/common/src/functional_dependencies.rs @@ -364,7 +364,7 @@ impl FunctionalDependencies { // These joins preserve functional dependencies of the left side: left_func_dependencies } - JoinType::RightSemi | JoinType::RightAnti => { + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { // These joins preserve functional dependencies of the right side: right_func_dependencies } diff --git a/datafusion/common/src/join_type.rs b/datafusion/common/src/join_type.rs index ac81d977b7..d9a1478f02 100644 --- a/datafusion/common/src/join_type.rs +++ b/datafusion/common/src/join_type.rs @@ -67,6 +67,11 @@ pub enum JoinType { /// /// [1]: http://btw2017.informatik.uni-stuttgart.de/slidesandpapers/F1-10-37/paper_web.pdf LeftMark, + /// Right Mark Join + /// + /// Same logic as the LeftMark Join above, however it returns a record for each record from the + /// right input. + RightMark, } impl JoinType { @@ -87,13 +92,12 @@ impl JoinType { JoinType::RightSemi => JoinType::LeftSemi, JoinType::LeftAnti => JoinType::RightAnti, JoinType::RightAnti => JoinType::LeftAnti, - JoinType::LeftMark => { - unreachable!("LeftMark join type does not support swapping") - } + JoinType::LeftMark => JoinType::RightMark, + JoinType::RightMark => JoinType::LeftMark, } } - /// Does the join type support swapping inputs? + /// Does the join type support swapping inputs? pub fn supports_swap(&self) -> bool { matches!( self, @@ -121,6 +125,7 @@ impl Display for JoinType { JoinType::LeftAnti => "LeftAnti", JoinType::RightAnti => "RightAnti", JoinType::LeftMark => "LeftMark", + JoinType::RightMark => "RightMark", }; write!(f, "{join_type}") } @@ -141,6 +146,7 @@ impl FromStr for JoinType { "LEFTANTI" => Ok(JoinType::LeftAnti), "RIGHTANTI" => Ok(JoinType::RightAnti), "LEFTMARK" => Ok(JoinType::LeftMark), + "RIGHTMARK" => Ok(JoinType::RightMark), _ => _not_impl_err!("The join type {s} does not exist or is not implemented"), } } diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index f198907cf5..3c4f8018a1 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -2145,6 +2145,7 @@ async fn verify_join_output_partitioning() -> Result<()> { JoinType::LeftAnti, JoinType::RightAnti, JoinType::LeftMark, + JoinType::RightMark, ]; let default_partition_count = SessionConfig::new().target_partitions(); @@ -2178,7 +2179,8 @@ async fn verify_join_output_partitioning() -> Result<()> { JoinType::Inner | JoinType::Right | JoinType::RightSemi - | JoinType::RightAnti => { + | JoinType::RightAnti + | JoinType::RightMark => { let right_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![ Arc::new(Column::new_with_schema("c2_c1", &join_schema)?), Arc::new(Column::new_with_schema("c2_c2", &join_schema)?), diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 82ee73b525..1a8064ac1e 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -305,6 +305,31 @@ async fn test_left_mark_join_1k_filtered() { .await } +// todo: add JoinTestType::HjSmj after Right mark SortMergeJoin support +#[tokio::test] +async fn test_right_mark_join_1k() { + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::RightMark, + None, + ) + .run_test(&[NljHj], false) + .await +} + +#[tokio::test] +async fn test_right_mark_join_1k_filtered() { + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::RightMark, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[NljHj], false) + .await +} + type JoinFilterBuilder = Box<dyn Fn(Arc<Schema>, Arc<Schema>) -> JoinFilter>; struct JoinFuzzTestCase { diff --git a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs index 7054500865..fd84776312 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs @@ -627,7 +627,7 @@ fn multi_hash_joins() -> Result<()> { test_config.run(&expected, top_join.clone(), &DISTRIB_DISTRIB_SORT)?; test_config.run(&expected, top_join, &SORT_DISTRIB_DISTRIB)?; } - JoinType::RightSemi | JoinType::RightAnti => {} + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {} } match join_type { @@ -636,7 +636,8 @@ fn multi_hash_joins() -> Result<()> { | JoinType::Right | JoinType::Full | JoinType::RightSemi - | JoinType::RightAnti => { + | JoinType::RightAnti + | JoinType::RightMark => { // This time we use (b1 == c) for top join // Join on (b1 == c) let top_join_on = vec![( diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 533e81e64f..fbd1fdadc4 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1623,6 +1623,10 @@ pub fn build_join_schema( .map(|(q, f)| (q.cloned(), Arc::clone(f))) .collect() } + JoinType::RightMark => right_fields + .map(|(q, f)| (q.cloned(), Arc::clone(f))) + .chain(once(mark_field(left))) + .collect(), }; let func_dependencies = left.functional_dependencies().join( right.functional_dependencies(), diff --git a/datafusion/expr/src/logical_plan/invariants.rs b/datafusion/expr/src/logical_plan/invariants.rs index 0c30c97857..d8d6739b0e 100644 --- a/datafusion/expr/src/logical_plan/invariants.rs +++ b/datafusion/expr/src/logical_plan/invariants.rs @@ -310,7 +310,10 @@ fn check_inner_plan(inner_plan: &LogicalPlan) -> Result<()> { check_inner_plan(left)?; check_no_outer_references(right) } - JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => { + JoinType::Right + | JoinType::RightSemi + | JoinType::RightAnti + | JoinType::RightMark => { check_no_outer_references(left)?; check_inner_plan(right) } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 208a8510c3..4ac2d182aa 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -556,7 +556,9 @@ impl LogicalPlan { JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { left.head_output_expr() } - JoinType::RightSemi | JoinType::RightAnti => right.head_output_expr(), + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { + right.head_output_expr() + } }, LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { static_term.head_output_expr() @@ -1340,7 +1342,9 @@ impl LogicalPlan { JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { left.max_rows() } - JoinType::RightSemi | JoinType::RightAnti => right.max_rows(), + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { + right.max_rows() + } }, LogicalPlan::Repartition(Repartition { input, .. }) => input.max_rows(), LogicalPlan::Union(Union { inputs, .. }) => { diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index d0457e7090..33af52824a 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -672,10 +672,10 @@ fn outer_columns_helper_multi<'a, 'b>( /// Depending on the join type, it divides the requirement indices into those /// that apply to the left child and those that apply to the right child. /// -/// - For `INNER`, `LEFT`, `RIGHT` and `FULL` joins, the requirements are split -/// between left and right children. The right child indices are adjusted to -/// point to valid positions within the right child by subtracting the length -/// of the left child. +/// - For `INNER`, `LEFT`, `RIGHT`, `FULL`, `LEFTMARK`, and `RIGHTMARK` joins, +/// the requirements are split between left and right children. The right +/// child indices are adjusted to point to valid positions within the right +/// child by subtracting the length of the left child. /// /// - For `LEFT ANTI`, `LEFT SEMI`, `RIGHT SEMI` and `RIGHT ANTI` joins, all /// requirements are re-routed to either the left child or the right child @@ -704,7 +704,8 @@ fn split_join_requirements( | JoinType::Left | JoinType::Right | JoinType::Full - | JoinType::LeftMark => { + | JoinType::LeftMark + | JoinType::RightMark => { // Decrease right side indices by `left_len` so that they point to valid // positions within the right child: indices.split_off(left_len) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 1c1996d6a2..7c4a026788 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -168,7 +168,7 @@ pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) { JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => (true, false), // No columns from the left side of the join can be referenced in output // predicates for semi/anti joins, so whether we specify t/f doesn't matter. - JoinType::RightSemi | JoinType::RightAnti => (false, true), + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => (false, true), } } @@ -191,6 +191,7 @@ pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) { JoinType::LeftAnti => (false, true), JoinType::RightAnti => (true, false), JoinType::LeftMark => (false, true), + JoinType::RightMark => (true, false), } } @@ -691,7 +692,7 @@ fn infer_join_predicates_from_on_filters( inferred_predicates, ) } - JoinType::Right | JoinType::RightSemi => { + JoinType::Right | JoinType::RightSemi | JoinType::RightMark => { infer_join_predicates_impl::<false, true>( join_col_keys, on_filters, diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index 0e722f3dc8..8af6f3be03 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -791,7 +791,9 @@ impl EquivalenceGroup { result } JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => self.clone(), - JoinType::RightSemi | JoinType::RightAnti => right_equivalences.clone(), + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { + right_equivalences.clone() + } }; Ok(group) } diff --git a/datafusion/physical-optimizer/src/enforce_distribution.rs b/datafusion/physical-optimizer/src/enforce_distribution.rs index 478ce39eec..566cf2f3a2 100644 --- a/datafusion/physical-optimizer/src/enforce_distribution.rs +++ b/datafusion/physical-optimizer/src/enforce_distribution.rs @@ -334,7 +334,7 @@ pub fn adjust_input_keys_ordering( left.schema().fields().len(), ) .unwrap_or_default(), - JoinType::RightSemi | JoinType::RightAnti => { + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { requirements.data.clone() } JoinType::Left diff --git a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs index bd7b0060c3..a9c0e4cb28 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs @@ -488,7 +488,8 @@ fn expr_source_side( | JoinType::Left | JoinType::Right | JoinType::Full - | JoinType::LeftMark => { + | JoinType::LeftMark + | JoinType::RightMark => { let eq_group = eqp.eq_group(); let mut right_ordering = ordering.clone(); let (mut valid_left, mut valid_right) = (true, true); diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 96cb09b4cb..5034a199e2 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -475,6 +475,7 @@ impl HashJoinExec { | JoinType::Right | JoinType::RightAnti | JoinType::RightSemi + | JoinType::RightMark ), ] } @@ -556,7 +557,8 @@ impl HashJoinExec { | JoinType::LeftSemi | JoinType::RightSemi | JoinType::Right - | JoinType::RightAnti => EmissionType::Incremental, + | JoinType::RightAnti + | JoinType::RightMark => EmissionType::Incremental, // If we need to generate unmatched rows from the *build side*, // we need to emit them at the end. JoinType::Left @@ -1561,15 +1563,27 @@ impl HashJoinStream { self.right_side_ordered, )?; - let result = build_batch_from_indices( - &self.schema, - build_side.left_data.batch(), - &state.batch, - &left_indices, - &right_indices, - &self.column_indices, - JoinSide::Left, - )?; + let result = if self.join_type == JoinType::RightMark { + build_batch_from_indices( + &self.schema, + &state.batch, + build_side.left_data.batch(), + &left_indices, + &right_indices, + &self.column_indices, + JoinSide::Right, + )? + } else { + build_batch_from_indices( + &self.schema, + build_side.left_data.batch(), + &state.batch, + &left_indices, + &right_indices, + &self.column_indices, + JoinSide::Left, + )? + }; self.join_metrics.output_batches.add(1); self.join_metrics.output_rows.add(result.num_rows()); @@ -3331,6 +3345,95 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] + #[tokio::test] + async fn join_right_mark(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), // 6 does not exist on the left + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + )]; + + let (columns, batches) = join_collect( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + &JoinType::RightMark, + false, + task_ctx, + ) + .await?; + assert_eq!(columns, vec!["a2", "b1", "c2", "mark"]); + + let expected = [ + "+----+----+----+-------+", + "| a2 | b1 | c2 | mark |", + "+----+----+----+-------+", + "| 10 | 4 | 70 | true |", + "| 20 | 5 | 80 | true |", + "| 30 | 6 | 90 | false |", + "+----+----+----+-------+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[apply(batch_sizes)] + #[tokio::test] + async fn partitioned_join_right_mark(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30, 40]), + ("b1", &vec![4, 4, 5, 6]), // 6 does not exist on the left + ("c2", &vec![60, 70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + )]; + + let (columns, batches) = partitioned_join_collect( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + &JoinType::RightMark, + false, + task_ctx, + ) + .await?; + assert_eq!(columns, vec!["a2", "b1", "c2", "mark"]); + + let expected = [ + "+----+----+----+-------+", + "| a2 | b1 | c2 | mark |", + "+----+----+----+-------+", + "| 10 | 4 | 60 | true |", + "| 20 | 4 | 70 | true |", + "| 30 | 5 | 80 | true |", + "| 40 | 6 | 90 | false |", + "+----+----+----+-------+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + #[test] fn join_with_hash_collision() -> Result<()> { let mut hashmap_left = HashTable::with_capacity(4); @@ -3759,6 +3862,15 @@ mod tests { "| 3 | 7 | 9 | false |", "+----+----+----+-------+", ]; + let expected_right_mark = vec![ + "+----+----+----+-------+", + "| a2 | b2 | c2 | mark |", + "+----+----+----+-------+", + "| 10 | 4 | 70 | true |", + "| 20 | 5 | 80 | true |", + "| 30 | 6 | 90 | false |", + "+----+----+----+-------+", + ]; let test_cases = vec![ (JoinType::Inner, expected_inner), @@ -3770,6 +3882,7 @@ mod tests { (JoinType::RightSemi, expected_right_semi), (JoinType::RightAnti, expected_right_anti), (JoinType::LeftMark, expected_left_mark), + (JoinType::RightMark, expected_right_mark), ]; for (join_type, expected) in test_cases { @@ -4049,6 +4162,7 @@ mod tests { JoinType::RightSemi, JoinType::RightAnti, JoinType::LeftMark, + JoinType::RightMark, ]; for join_type in join_types { diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index c9b5b9e43b..44021c38a7 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -274,7 +274,8 @@ impl NestedLoopJoinExec { | JoinType::LeftSemi | JoinType::RightSemi | JoinType::Right - | JoinType::RightAnti => EmissionType::Incremental, + | JoinType::RightAnti + | JoinType::RightMark => EmissionType::Incremental, // If we need to generate unmatched rows from the *build side*, // we need to emit them at the end. JoinType::Left @@ -1009,15 +1010,30 @@ fn join_left_and_right_batch( right_side_ordered, )?; - build_batch_from_indices( - schema, - left_batch, - right_batch, - &left_side, - &right_side, - column_indices, - JoinSide::Left, - ) + // Switch around the build side and probe side for `JoinType::RightMark` + // because in a RightMark join, we want to mark rows on the right table + // by looking for matches in the left. + if join_type == JoinType::RightMark { + build_batch_from_indices( + schema, + right_batch, + left_batch, + &left_side, + &right_side, + column_indices, + JoinSide::Right, + ) + } else { + build_batch_from_indices( + schema, + left_batch, + right_batch, + &left_side, + &right_side, + column_indices, + JoinSide::Left, + ) + } } impl<T: BatchTransformer + Unpin + Send> Stream for NestedLoopJoinStream<T> { @@ -1460,6 +1476,36 @@ pub(crate) mod tests { Ok(()) } + #[tokio::test] + async fn join_right_mark_with_filter() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let left = build_left_table(); + let right = build_right_table(); + + let filter = prepare_join_filter(); + let (columns, batches) = multi_partitioned_join_collect( + left, + right, + &JoinType::RightMark, + Some(filter), + task_ctx, + ) + .await?; + assert_eq!(columns, vec!["a2", "b2", "c2", "mark"]); + + assert_snapshot!(batches_to_sort_string(&batches), @r#" + +----+----+-----+-------+ + | a2 | b2 | c2 | mark | + +----+----+-----+-------+ + | 10 | 10 | 100 | false | + | 12 | 10 | 40 | false | + | 2 | 2 | 80 | true | + +----+----+-----+-------+ + "#); + + Ok(()) + } + #[tokio::test] async fn test_overallocation() -> Result<()> { let left = build_table( @@ -1488,6 +1534,7 @@ pub(crate) mod tests { JoinType::LeftMark, JoinType::RightSemi, JoinType::RightAnti, + JoinType::RightMark, ]; for join_type in join_types { diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index eab3796b15..f361992caa 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -229,9 +229,11 @@ impl SortMergeJoinExec { // When output schema contains only the right side, probe side is right. // Otherwise probe side is the left side. match join_type { - JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => { - JoinSide::Right - } + // TODO: sort merge support for right mark (tracked here: https://github.com/apache/datafusion/issues/16226) + JoinType::Right + | JoinType::RightSemi + | JoinType::RightAnti + | JoinType::RightMark => JoinSide::Right, JoinType::Inner | JoinType::Left | JoinType::Full @@ -249,7 +251,10 @@ impl SortMergeJoinExec { | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => vec![true, false], - JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => { + JoinType::Right + | JoinType::RightSemi + | JoinType::RightAnti + | JoinType::RightMark => { vec![false, true] } _ => vec![false, false], diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 95497f5093..84575acea5 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -770,7 +770,11 @@ fn need_to_produce_result_in_final(build_side: JoinSide, join_type: JoinType) -> } else { matches!( join_type, - JoinType::Right | JoinType::RightAnti | JoinType::Full | JoinType::RightSemi + JoinType::Right + | JoinType::RightAnti + | JoinType::Full + | JoinType::RightSemi + | JoinType::RightMark ) } } @@ -818,6 +822,20 @@ where .collect(); (build_indices, probe_indices) } + (JoinSide::Right, JoinType::RightMark) => { + let build_indices = (0..prune_length) + .map(L::Native::from_usize) + .collect::<PrimitiveArray<L>>(); + let probe_indices = (0..prune_length) + .map(|idx| { + // For mark join we output a dummy index 0 to indicate the row had a match + visited_rows + .contains(&(idx + deleted_offset)) + .then_some(R::Native::from_usize(0).unwrap()) + }) + .collect(); + (build_indices, probe_indices) + } // In the case of `Left` or `Right` join, or `Full` join, get the anti indices (JoinSide::Left, JoinType::Left | JoinType::LeftAnti) | (JoinSide::Right, JoinType::Right | JoinType::RightAnti) diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index f473f82ee2..c5f7087ac1 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -41,6 +41,7 @@ use arrow::array::{ BooleanBufferBuilder, NativeAdapter, PrimitiveArray, RecordBatch, RecordBatchOptions, UInt32Array, UInt32Builder, UInt64Array, }; +use arrow::buffer::NullBuffer; use arrow::compute; use arrow::datatypes::{ ArrowNativeType, Field, Schema, SchemaBuilder, UInt32Type, UInt64Type, @@ -216,6 +217,7 @@ fn output_join_field(old_field: &Field, join_type: &JoinType, is_left: bool) -> JoinType::LeftAnti => false, // doesn't introduce nulls (or can it??) JoinType::RightAnti => false, // doesn't introduce nulls (or can it??) JoinType::LeftMark => false, + JoinType::RightMark => false, }; if force_nullable { @@ -282,6 +284,16 @@ pub fn build_join_schema( left_fields().chain(right_field).unzip() } JoinType::RightSemi | JoinType::RightAnti => right_fields().unzip(), + JoinType::RightMark => { + let left_field = once(( + Field::new("mark", arrow_schema::DataType::Boolean, false), + ColumnIndex { + index: 0, + side: JoinSide::None, + }, + )); + right_fields().chain(left_field).unzip() + } }; let (schema1, schema2) = match join_type { @@ -509,6 +521,15 @@ fn estimate_join_cardinality( column_statistics, }) } + JoinType::RightMark => { + let num_rows = *right_stats.num_rows.get_value()?; + let mut column_statistics = right_stats.column_statistics; + column_statistics.push(ColumnStatistics::new_unknown()); + Some(PartialJoinStatistics { + num_rows, + column_statistics, + }) + } } } @@ -880,7 +901,7 @@ pub(crate) fn build_batch_from_indices( for column_index in column_indices { let array = if column_index.side == JoinSide::None { - // LeftMark join, the mark column is a true if the indices is not null, otherwise it will be false + // For mark joins, the mark column is a true if the indices is not null, otherwise it will be false Arc::new(compute::is_not_null(probe_indices)?) } else if column_index.side == build_side { let array = build_input_buffer.column(column_index.index); @@ -951,6 +972,12 @@ pub(crate) fn adjust_indices_by_join_type( // the left_indices will not be used later for the `right anti` join Ok((left_indices, right_indices)) } + JoinType::RightMark => { + let right_indices = get_mark_indices(&adjust_range, &right_indices); + let left_indices_vec: Vec<u64> = adjust_range.map(|i| i as u64).collect(); + let left_indices = UInt64Array::from(left_indices_vec); + Ok((left_indices, right_indices)) + } JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { // matched or unmatched left row will be produced in the end of loop // When visit the right batch, we can output the matched left row and don't need to wait the end of loop @@ -1052,17 +1079,7 @@ pub(crate) fn get_anti_indices<T: ArrowPrimitiveType>( where NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>, { - let mut bitmap = BooleanBufferBuilder::new(range.len()); - bitmap.append_n(range.len(), false); - input_indices - .iter() - .flatten() - .map(|v| v.as_usize()) - .filter(|v| range.contains(v)) - .for_each(|v| { - bitmap.set_bit(v - range.start, true); - }); - + let bitmap = build_range_bitmap(&range, input_indices); let offset = range.start; // get the anti index @@ -1081,19 +1098,8 @@ pub(crate) fn get_semi_indices<T: ArrowPrimitiveType>( where NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>, { - let mut bitmap = BooleanBufferBuilder::new(range.len()); - bitmap.append_n(range.len(), false); - input_indices - .iter() - .flatten() - .map(|v| v.as_usize()) - .filter(|v| range.contains(v)) - .for_each(|v| { - bitmap.set_bit(v - range.start, true); - }); - + let bitmap = build_range_bitmap(&range, input_indices); let offset = range.start; - // get the semi index (range) .filter_map(|idx| { @@ -1102,6 +1108,37 @@ where .collect() } +pub(crate) fn get_mark_indices<T: ArrowPrimitiveType>( + range: &Range<usize>, + input_indices: &PrimitiveArray<T>, +) -> PrimitiveArray<UInt32Type> +where + NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>, +{ + let mut bitmap = build_range_bitmap(range, input_indices); + PrimitiveArray::new( + vec![0; range.len()].into(), + Some(NullBuffer::new(bitmap.finish())), + ) +} + +fn build_range_bitmap<T: ArrowPrimitiveType>( + range: &Range<usize>, + input: &PrimitiveArray<T>, +) -> BooleanBufferBuilder { + let mut builder = BooleanBufferBuilder::new(range.len()); + builder.append_n(range.len(), false); + + input.iter().flatten().for_each(|v| { + let idx = v.as_usize(); + if range.contains(&idx) { + builder.set_bit(idx - range.start, true); + } + }); + + builder +} + /// Appends probe indices in order by considering the given build indices. /// /// This function constructs new build and probe indices by iterating through @@ -1277,7 +1314,9 @@ pub(crate) fn symmetric_join_output_partitioning( JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { left_partitioning.clone() } - JoinType::RightSemi | JoinType::RightAnti => right_partitioning.clone(), + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { + right_partitioning.clone() + } JoinType::Inner | JoinType::Right => { adjust_right_output_partitioning(right_partitioning, left_columns_len)? } @@ -1299,7 +1338,9 @@ pub(crate) fn asymmetric_join_output_partitioning( right.output_partitioning(), left.schema().fields().len(), )?, - JoinType::RightSemi | JoinType::RightAnti => right.output_partitioning().clone(), + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { + right.output_partitioning().clone() + } JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti diff --git a/datafusion/proto-common/proto/datafusion_common.proto b/datafusion/proto-common/proto/datafusion_common.proto index 35f41155fa..9eab33928a 100644 --- a/datafusion/proto-common/proto/datafusion_common.proto +++ b/datafusion/proto-common/proto/datafusion_common.proto @@ -85,6 +85,7 @@ enum JoinType { RIGHTSEMI = 6; RIGHTANTI = 7; LEFTMARK = 8; + RIGHTMARK = 9; } enum JoinConstraint { diff --git a/datafusion/proto-common/src/generated/pbjson.rs b/datafusion/proto-common/src/generated/pbjson.rs index 1ac35742c7..0c593a36b8 100644 --- a/datafusion/proto-common/src/generated/pbjson.rs +++ b/datafusion/proto-common/src/generated/pbjson.rs @@ -3838,6 +3838,7 @@ impl serde::Serialize for JoinType { Self::Rightsemi => "RIGHTSEMI", Self::Rightanti => "RIGHTANTI", Self::Leftmark => "LEFTMARK", + Self::Rightmark => "RIGHTMARK", }; serializer.serialize_str(variant) } @@ -3858,6 +3859,7 @@ impl<'de> serde::Deserialize<'de> for JoinType { "RIGHTSEMI", "RIGHTANTI", "LEFTMARK", + "RIGHTMARK", ]; struct GeneratedVisitor; @@ -3907,6 +3909,7 @@ impl<'de> serde::Deserialize<'de> for JoinType { "RIGHTSEMI" => Ok(JoinType::Rightsemi), "RIGHTANTI" => Ok(JoinType::Rightanti), "LEFTMARK" => Ok(JoinType::Leftmark), + "RIGHTMARK" => Ok(JoinType::Rightmark), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto-common/src/generated/prost.rs b/datafusion/proto-common/src/generated/prost.rs index a55714f190..c051dd00f7 100644 --- a/datafusion/proto-common/src/generated/prost.rs +++ b/datafusion/proto-common/src/generated/prost.rs @@ -904,6 +904,7 @@ pub enum JoinType { Rightsemi = 6, Rightanti = 7, Leftmark = 8, + Rightmark = 9, } impl JoinType { /// String value of the enum field names used in the ProtoBuf definition. @@ -921,6 +922,7 @@ impl JoinType { Self::Rightsemi => "RIGHTSEMI", Self::Rightanti => "RIGHTANTI", Self::Leftmark => "LEFTMARK", + Self::Rightmark => "RIGHTMARK", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -935,6 +937,7 @@ impl JoinType { "RIGHTSEMI" => Some(Self::Rightsemi), "RIGHTANTI" => Some(Self::Rightanti), "LEFTMARK" => Some(Self::Leftmark), + "RIGHTMARK" => Some(Self::Rightmark), _ => None, } } diff --git a/datafusion/proto/src/generated/datafusion_proto_common.rs b/datafusion/proto/src/generated/datafusion_proto_common.rs index a55714f190..c051dd00f7 100644 --- a/datafusion/proto/src/generated/datafusion_proto_common.rs +++ b/datafusion/proto/src/generated/datafusion_proto_common.rs @@ -904,6 +904,7 @@ pub enum JoinType { Rightsemi = 6, Rightanti = 7, Leftmark = 8, + Rightmark = 9, } impl JoinType { /// String value of the enum field names used in the ProtoBuf definition. @@ -921,6 +922,7 @@ impl JoinType { Self::Rightsemi => "RIGHTSEMI", Self::Rightanti => "RIGHTANTI", Self::Leftmark => "LEFTMARK", + Self::Rightmark => "RIGHTMARK", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -935,6 +937,7 @@ impl JoinType { "RIGHTSEMI" => Some(Self::Rightsemi), "RIGHTANTI" => Some(Self::Rightanti), "LEFTMARK" => Some(Self::Leftmark), + "RIGHTMARK" => Some(Self::Rightmark), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 1b5527c14a..162ed7ae25 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -205,6 +205,7 @@ impl From<protobuf::JoinType> for JoinType { protobuf::JoinType::Leftanti => JoinType::LeftAnti, protobuf::JoinType::Rightanti => JoinType::RightAnti, protobuf::JoinType::Leftmark => JoinType::LeftMark, + protobuf::JoinType::Rightmark => JoinType::RightMark, } } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index f931560403..814cf52904 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -682,6 +682,7 @@ impl From<JoinType> for protobuf::JoinType { JoinType::LeftAnti => protobuf::JoinType::Leftanti, JoinType::RightAnti => protobuf::JoinType::Rightanti, JoinType::LeftMark => protobuf::JoinType::Leftmark, + JoinType::RightMark => protobuf::JoinType::Rightmark, } } } diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index f667761703..d9f9767ba9 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -714,7 +714,8 @@ impl Unparser<'_> { | JoinType::LeftAnti | JoinType::LeftMark | JoinType::RightSemi - | JoinType::RightAnti => { + | JoinType::RightAnti + | JoinType::RightMark => { let mut query_builder = QueryBuilder::default(); let mut from = TableWithJoinsBuilder::default(); let mut exists_select: SelectBuilder = SelectBuilder::default(); @@ -738,7 +739,8 @@ impl Unparser<'_> { let negated = match join.join_type { JoinType::LeftSemi | JoinType::RightSemi - | JoinType::LeftMark => false, + | JoinType::LeftMark + | JoinType::RightMark => false, JoinType::LeftAnti | JoinType::RightAnti => true, _ => unreachable!(), }; @@ -746,13 +748,25 @@ impl Unparser<'_> { subquery: Box::new(query_builder.build()?), negated, }; - if join.join_type == JoinType::LeftMark { - let (table_ref, _) = right_plan.schema().qualified_field(0); - let column = self - .col_to_sql(&Column::new(table_ref.cloned(), "mark"))?; - select.replace_mark(&column, &exists_expr); - } else { - select.selection(Some(exists_expr)); + + match join.join_type { + JoinType::LeftMark | JoinType::RightMark => { + let source_schema = + if join.join_type == JoinType::LeftMark { + right_plan.schema() + } else { + left_plan.schema() + }; + let (table_ref, _) = source_schema.qualified_field(0); + let column = self.col_to_sql(&Column::new( + table_ref.cloned(), + "mark", + ))?; + select.replace_mark(&column, &exists_expr); + } + _ => { + select.selection(Some(exists_expr)); + } } if let Some(projection) = left_projection { select.projection(projection); @@ -1244,7 +1258,9 @@ impl Unparser<'_> { JoinType::LeftSemi => ast::JoinOperator::LeftSemi(constraint), JoinType::RightAnti => ast::JoinOperator::RightAnti(constraint), JoinType::RightSemi => ast::JoinOperator::RightSemi(constraint), - JoinType::LeftMark => unimplemented!("Unparsing of Left Mark join type"), + JoinType::LeftMark | JoinType::RightMark => { + unimplemented!("Unparsing of Mark join type") + } }) } diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/join_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/join_rel.rs index 881157dcfa..fab43a5ff4 100644 --- a/datafusion/substrait/src/logical_plan/consumer/rel/join_rel.rs +++ b/datafusion/substrait/src/logical_plan/consumer/rel/join_rel.rs @@ -132,6 +132,7 @@ fn from_substrait_jointype(join_type: i32) -> datafusion::common::Result<JoinTyp join_rel::JoinType::LeftAnti => Ok(JoinType::LeftAnti), join_rel::JoinType::LeftSemi => Ok(JoinType::LeftSemi), join_rel::JoinType::LeftMark => Ok(JoinType::LeftMark), + join_rel::JoinType::RightMark => Ok(JoinType::RightMark), _ => plan_err!("unsupported join type {substrait_join_type:?}"), } } else { diff --git a/datafusion/substrait/src/logical_plan/producer/rel/join.rs b/datafusion/substrait/src/logical_plan/producer/rel/join.rs index 79564ad5da..65c3e426d2 100644 --- a/datafusion/substrait/src/logical_plan/producer/rel/join.rs +++ b/datafusion/substrait/src/logical_plan/producer/rel/join.rs @@ -113,6 +113,7 @@ fn to_substrait_jointype(join_type: JoinType) -> join_rel::JoinType { JoinType::LeftAnti => join_rel::JoinType::LeftAnti, JoinType::LeftSemi => join_rel::JoinType::LeftSemi, JoinType::LeftMark => join_rel::JoinType::LeftMark, + JoinType::RightMark => join_rel::JoinType::RightMark, JoinType::RightAnti | JoinType::RightSemi => { unimplemented!() } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org For additional commands, e-mail: commits-h...@datafusion.apache.org